Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 87922d994f |
-10
@@ -1,15 +1,5 @@
|
|||||||
.zed
|
|
||||||
.idea
|
.idea
|
||||||
**/.vscode
|
**/.vscode
|
||||||
|
|
||||||
.claude
|
.claude
|
||||||
.codex
|
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
|
|
||||||
CMakeUserPresets.json
|
|
||||||
|
|
||||||
build
|
build
|
||||||
cmake-build-debug
|
|
||||||
cmake-build-release
|
|
||||||
|
|
||||||
**/__*
|
|
||||||
|
|||||||
@@ -1,159 +1,5 @@
|
|||||||
# Raptor
|
# Raptor
|
||||||
|
|
||||||
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
|
|
||||||
targeting in-memory computing / processing-in-memory (PIM) architectures.
|
|
||||||
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
|
|
||||||
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
PIM architectures perform most of the computation directly in memory.
|
|
||||||
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
|
|
||||||
- a shared host memory,
|
|
||||||
- a number of cores that do most of the computation directly in their memory
|
|
||||||
(vector ops, vmm/mvm on ReRAM crossbars),
|
|
||||||
- no branching instructions (branchless architecture) and no hardware loop
|
|
||||||
support — any repeated work (e.g. convolutions) must be unrolled into
|
|
||||||
explicit per-iteration instructions.
|
|
||||||
|
|
||||||
Because of this, the amount of emitted instructions explodes quickly and the
|
|
||||||
compiler must optimize aggressively at every stage to keep compilation
|
|
||||||
tractable.
|
|
||||||
|
|
||||||
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
|
|
||||||
each carrying its own in-memory computing unit and crossbars. It will live in
|
|
||||||
a dedicated dialect (future work).
|
|
||||||
|
|
||||||
### Targets and simulators
|
|
||||||
|
|
||||||
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
|
|
||||||
**performance** estimates (latency, energy), but does not functionally execute
|
|
||||||
the JSON code it consumes. To validate the numerical correctness of the JSON
|
|
||||||
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
|
|
||||||
a Rust simulator we maintain in-tree at
|
|
||||||
`backend-simulators/pim/pim-simulator`.
|
|
||||||
|
|
||||||
## Compilation pipeline
|
|
||||||
|
|
||||||
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
|
|
||||||
When working on this codebase, most changes should stay confined to those
|
|
||||||
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
|
|
||||||
framework-level details).
|
|
||||||
|
|
||||||
High-level lowering flow:
|
|
||||||
|
|
||||||
```
|
|
||||||
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
|
|
||||||
```
|
|
||||||
|
|
||||||
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
|
||||||
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
|
|
||||||
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
|
|
||||||
operations are accelerated by storing a constant RHS matrix into a
|
|
||||||
crossbar. Crossbars cannot be re-programmed during execution, have a
|
|
||||||
limited fixed size, and there is a limited number of them per core.
|
|
||||||
Conversion patterns are split by op family under
|
|
||||||
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
|
|
||||||
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
|
|
||||||
Reshape, Resize, Split).
|
|
||||||
|
|
||||||
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
|
|
||||||
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
|
|
||||||
materializes PIM cores (`pim.core`), inter-core communication
|
|
||||||
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
|
|
||||||
|
|
||||||
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
|
||||||
A DCP-inspired heuristic (Dynamic Critical Path — see the original
|
|
||||||
scheduling paper by Kwok & Ahmad,
|
|
||||||
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
|
|
||||||
that coarsens the virtual node graph and decides how to group compute
|
|
||||||
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
|
|
||||||
heuristic with different assumptions from the paper (different cost
|
|
||||||
model, constraints from crossbar capacity / core resources, and a
|
|
||||||
windowed coarsening loop instead of full-graph reprioritization). The
|
|
||||||
`dcp-critical-window-size` option controls how many lowest-slack virtual
|
|
||||||
nodes each coarsening iteration considers (0 = legacy full-graph
|
|
||||||
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
|
|
||||||
`MergeComputeNodesPass.cpp`.
|
|
||||||
|
|
||||||
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
|
|
||||||
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
|
|
||||||
standard MLIR `BufferizableOpInterface` machinery
|
|
||||||
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
|
|
||||||
|
|
||||||
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
|
||||||
- `HostConstantFolding` — folds host-side constants.
|
|
||||||
- `MaterializeHostConstantsPass` — materializes the remaining host
|
|
||||||
constants for emission.
|
|
||||||
- `VerificationPass` — checks invariants before emission.
|
|
||||||
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
|
|
||||||
and `pim-simulator`.
|
|
||||||
|
|
||||||
Supporting pieces:
|
|
||||||
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
|
|
||||||
core count, DCP window, experimental conv impl, concat error handling, …)
|
|
||||||
and `PimCodeGen` entry points.
|
|
||||||
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
|
||||||
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
|
||||||
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
|
||||||
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
|
||||||
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
|
||||||
|
|
||||||
## Key compiler options
|
|
||||||
|
|
||||||
Pass these on the `onnx-mlir` command line when compiling for PIM:
|
|
||||||
|
|
||||||
- `--maccel=PIM` — select the PIM accelerator.
|
|
||||||
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
|
|
||||||
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
|
|
||||||
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
|
|
||||||
run only the codegen tail.
|
|
||||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
|
||||||
per-core count.
|
|
||||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
|
||||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
|
||||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
|
||||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
|
||||||
|
|
||||||
## Validation
|
|
||||||
|
|
||||||
Functional validation lives in `validation/` and drives the Rust
|
|
||||||
`pim-simulator` to compare Raptor's output against a reference.
|
|
||||||
|
|
||||||
Per-operation validation (from `validation/`):
|
|
||||||
|
|
||||||
```
|
|
||||||
validate.py \
|
|
||||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
|
||||||
--onnx-include-dir ../onnx-mlir/include
|
|
||||||
```
|
|
||||||
|
|
||||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
|
||||||
|
|
||||||
```
|
|
||||||
validate.py \
|
|
||||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
|
||||||
--onnx-include-dir ../onnx-mlir/include \
|
|
||||||
--operations-dir ./networks/yolo11n/depth_04 \
|
|
||||||
--crossbar-size 2048 --crossbar-count 256
|
|
||||||
```
|
|
||||||
|
|
||||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
|
||||||
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
|
||||||
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
|
||||||
`sigmoid`, `softmax`, `split`.
|
|
||||||
|
|
||||||
## Rebuilding
|
|
||||||
|
|
||||||
Release build (fast):
|
|
||||||
|
|
||||||
```
|
|
||||||
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
|
|
||||||
```
|
|
||||||
|
|
||||||
A slower debug build is also available — configure it the same way but with
|
|
||||||
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
|
|
||||||
|
|
||||||
## Build
|
## Build
|
||||||
|
|
||||||
### Protobuf
|
### Protobuf
|
||||||
|
|||||||
@@ -55,23 +55,15 @@ pub trait HasSigm {
|
|||||||
|
|
||||||
impl HasSigm for f32 {
|
impl HasSigm for f32 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
if self >= 0.0 {
|
let ex = self.exp();
|
||||||
1.0 / (1.0 + (-self).exp())
|
ex / (1.0 + ex)
|
||||||
} else {
|
|
||||||
let ex = self.exp();
|
|
||||||
ex / (1.0 + ex)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl HasSigm for f64 {
|
impl HasSigm for f64 {
|
||||||
fn sigm(self) -> Self {
|
fn sigm(self) -> Self {
|
||||||
if self >= 0.0 {
|
let ex = self.exp();
|
||||||
1.0 / (1.0 + (-self).exp())
|
ex / (1.0 + ex)
|
||||||
} else {
|
|
||||||
let ex = self.exp();
|
|
||||||
ex / (1.0 + ex)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,5 @@
|
|||||||
add_pim_library(OMPimCommon
|
add_pim_library(OMPimCommon
|
||||||
IR/AddressAnalysis.cpp
|
PimCommon.cpp
|
||||||
IR/CoreBlockUtils.cpp
|
|
||||||
IR/EntryPointUtils.cpp
|
|
||||||
IR/ShapeUtils.cpp
|
|
||||||
IR/WeightUtils.cpp
|
|
||||||
Support/DebugDump.cpp
|
|
||||||
Support/Diagnostics.cpp
|
|
||||||
Support/FileSystemUtils.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -1,258 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
|
|
||||||
if (!moduleOp || !getGlobalOp)
|
|
||||||
return {};
|
|
||||||
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (!knowledge)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
auto iter = knowledge->aliases.find(value);
|
|
||||||
while (iter != knowledge->aliases.end()) {
|
|
||||||
value = iter->second;
|
|
||||||
iter = knowledge->aliases.find(value);
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (mlir::isa<mlir::BlockArgument>(value))
|
|
||||||
return value;
|
|
||||||
|
|
||||||
mlir::Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
|
|
||||||
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
|
||||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
|
||||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
|
||||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
|
|
||||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
|
||||||
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
if (knowledge) {
|
|
||||||
auto iter = knowledge->indexValues.find(value);
|
|
||||||
if (iter != knowledge->indexValues.end())
|
|
||||||
return iter->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
|
|
||||||
if (constantOp) {
|
|
||||||
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
|
||||||
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
|
||||||
|
|
||||||
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return mlir::failure();
|
|
||||||
return *lhs + *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return mlir::failure();
|
|
||||||
return *lhs - *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs))
|
|
||||||
return mlir::failure();
|
|
||||||
return *lhs * *rhs;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return mlir::failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
|
||||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
|
||||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
|
||||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
|
||||||
return mlir::failure();
|
|
||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
|
||||||
}
|
|
||||||
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
|
||||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
|
|
||||||
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
|
||||||
if (!integerAttr)
|
|
||||||
return mlir::failure();
|
|
||||||
return integerAttr.getInt();
|
|
||||||
}
|
|
||||||
|
|
||||||
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
|
|
||||||
const StaticValueKnowledge* knowledge) {
|
|
||||||
int64_t byteOffset = 0;
|
|
||||||
value = resolveAlias(value, knowledge);
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (mlir::isa<mlir::BlockArgument>(value))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
mlir::Operation* definingOp = value.getDefiningOp();
|
|
||||||
if (!definingOp)
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
|
||||||
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
|
||||||
if (!tiedOperand)
|
|
||||||
return mlir::failure();
|
|
||||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
|
||||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
|
||||||
if (!result)
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
|
||||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
|
||||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
|
||||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
|
||||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
|
||||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
value = yieldedValue;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
|
||||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
|
||||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
|
||||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> offsets;
|
|
||||||
llvm::SmallVector<int64_t> sizes;
|
|
||||||
llvm::SmallVector<int64_t> strides;
|
|
||||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
|
||||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
|
||||||
strides.reserve(subviewOp.getMixedStrides().size());
|
|
||||||
|
|
||||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
|
||||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
|
||||||
if (failed(resolvedOffset))
|
|
||||||
return mlir::failure();
|
|
||||||
offsets.push_back(*resolvedOffset);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
|
||||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
|
||||||
if (failed(resolvedSize))
|
|
||||||
return mlir::failure();
|
|
||||||
sizes.push_back(*resolvedSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
|
||||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
|
||||||
if (failed(resolvedStride))
|
|
||||||
return mlir::failure();
|
|
||||||
strides.push_back(*resolvedStride);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
|
||||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
|
||||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
|
||||||
value = resolveAlias(castOp.getSource(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
|
||||||
return ResolvedContiguousAddress {value, byteOffset};
|
|
||||||
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
|
||||||
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveIndexValueImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
|
||||||
return resolveContiguousAddressImpl(value, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
|
||||||
const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveContiguousAddressImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
|
||||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Describes a value as a base addressable object plus a statically known
|
|
||||||
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
|
||||||
struct ResolvedContiguousAddress {
|
|
||||||
mlir::Value base;
|
|
||||||
int64_t byteOffset = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Records compile-time facts used when interpreting address arithmetic and
|
|
||||||
/// loop-carried aliases inside PIM regions.
|
|
||||||
struct StaticValueKnowledge {
|
|
||||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
|
||||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
|
||||||
|
|
||||||
StaticValueKnowledge() {}
|
|
||||||
};
|
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
|
||||||
|
|
||||||
/// Resolves a value to contiguous backing storage when that storage can be
|
|
||||||
/// proven statically from aliases, DPS ties, casts, and subviews.
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
|
||||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
|
||||||
const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Statically evaluates index-like SSA values, including simple integer
|
|
||||||
/// arithmetic and loop facts recorded in `knowledge`.
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
|
||||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
|
||||||
/// loop-carried memref/result.
|
|
||||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
|
||||||
return mlir::isa<mlir::arith::ConstantOp,
|
|
||||||
mlir::arith::AddIOp,
|
|
||||||
mlir::arith::SubIOp,
|
|
||||||
mlir::arith::MulIOp,
|
|
||||||
mlir::arith::DivUIOp,
|
|
||||||
mlir::arith::RemUIOp,
|
|
||||||
mlir::arith::IndexCastOp,
|
|
||||||
mlir::memref::AllocOp,
|
|
||||||
mlir::memref::SubViewOp,
|
|
||||||
mlir::memref::CastOp,
|
|
||||||
mlir::memref::CollapseShapeOp,
|
|
||||||
mlir::memref::ExpandShapeOp>(op);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::LogicalResult
|
|
||||||
walkPimCoreBlock(mlir::Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
|
||||||
bool hasFailure = false;
|
|
||||||
for (mlir::Operation& op : block) {
|
|
||||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
|
||||||
mlir::Block& loopBody = forOp.getRegion().front();
|
|
||||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
|
||||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
|
||||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
|
||||||
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
|
||||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
|
||||||
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
|
||||||
StaticValueKnowledge loopKnowledge = knowledge;
|
|
||||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
|
||||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
|
||||||
loopKnowledge.aliases[iterArg] = iterValue;
|
|
||||||
|
|
||||||
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
|
||||||
hasFailure = true;
|
|
||||||
|
|
||||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
|
|
||||||
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
|
||||||
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(callback(op, knowledge)))
|
|
||||||
hasFailure = true;
|
|
||||||
}
|
|
||||||
return mlir::success(!hasFailure);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Block.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Returns true for ops in a `pim.core` body that only participate in static
|
|
||||||
/// address or index computation and therefore do not emit PIM instructions.
|
|
||||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
|
||||||
|
|
||||||
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
|
|
||||||
/// their bounds are known and invoking `callback` only on instruction-emitting
|
|
||||||
/// operations.
|
|
||||||
mlir::LogicalResult
|
|
||||||
walkPimCoreBlock(mlir::Block& block,
|
|
||||||
const StaticValueKnowledge& knowledge,
|
|
||||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
|
|
||||||
if (!moduleOp)
|
|
||||||
return mlir::failure();
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
|
|
||||||
if (entryPoints.size() > 1) {
|
|
||||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
if (!entryPoints.empty()) {
|
|
||||||
auto entryPointAttr =
|
|
||||||
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
|
|
||||||
if (!entryPointAttr) {
|
|
||||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
|
||||||
if (!entryFunc) {
|
|
||||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
|
||||||
<< entryPointAttr.getLeafReference().getValue();
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
return entryFunc;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
|
|
||||||
return mainGraphFunc;
|
|
||||||
|
|
||||||
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
|
|
||||||
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
|
|
||||||
if (!funcOp.isExternal())
|
|
||||||
nonExternalFuncs.push_back(funcOp);
|
|
||||||
if (nonExternalFuncs.size() == 1)
|
|
||||||
return nonExternalFuncs.front();
|
|
||||||
|
|
||||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Resolves the function the PIM pipeline should treat as its entry point.
|
|
||||||
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
|
|
||||||
/// non-external function if the module is otherwise unambiguous.
|
|
||||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
|
|
||||||
llvm::SmallVector<int64_t> strides(shape.size(), 1);
|
|
||||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
|
||||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
|
||||||
return strides;
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t>
|
|
||||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
|
|
||||||
llvm::SmallVector<int64_t> indices(shape.size(), 0);
|
|
||||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
|
||||||
indices[dim] = linearIndex / stride;
|
|
||||||
linearIndex %= stride;
|
|
||||||
}
|
|
||||||
return indices;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
|
|
||||||
int64_t linearIndex = 0;
|
|
||||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
|
||||||
linearIndex += index * stride;
|
|
||||||
return linearIndex;
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
|
||||||
int64_t numElements = 1;
|
|
||||||
for (int64_t dim : shape)
|
|
||||||
numElements *= dim;
|
|
||||||
return numElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
|
||||||
llvm::ArrayRef<int64_t> strides) {
|
|
||||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
|
||||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstNonZeroOffset = std::find_if(
|
|
||||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return offset != 0;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
|
||||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
|
||||||
if (size > dimension - offset)
|
|
||||||
return false;
|
|
||||||
++firstNonZeroOffset;
|
|
||||||
|
|
||||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
|
||||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
|
||||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
|
||||||
|
|
||||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, dimension] = sizeAndShape;
|
|
||||||
return size != dimension;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (firstDifferentSize != sizesAndShape.end()) {
|
|
||||||
++firstDifferentSize;
|
|
||||||
|
|
||||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
|
||||||
auto [size, _dimension] = sizeAndShape;
|
|
||||||
return size != 1;
|
|
||||||
}))
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
llvm::SmallVector<int64_t>
|
|
||||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
|
||||||
|
|
||||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
|
||||||
llvm::ArrayRef<int64_t> offsets,
|
|
||||||
llvm::ArrayRef<int64_t> sizes,
|
|
||||||
llvm::ArrayRef<int64_t> strides);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,101 +0,0 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
|
||||||
|
|
||||||
void markWeightAlways(mlir::Operation* op) {
|
|
||||||
assert(op && "expected valid op");
|
|
||||||
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
|
||||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
|
||||||
bool found = false;
|
|
||||||
parentOp.walk([&](mlir::Operation* op) {
|
|
||||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
|
||||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
|
||||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
|
||||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
|
||||||
});
|
|
||||||
return found;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
|
||||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
|
||||||
auto weights = parentOp.getWeights();
|
|
||||||
llvm::SmallSet<unsigned, 8> visited;
|
|
||||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
|
||||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
|
||||||
callback(parentOp->getOpOperand(weightIndex));
|
|
||||||
};
|
|
||||||
|
|
||||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
|
||||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
|
||||||
mlir::Operation* user = use.getOwner();
|
|
||||||
unsigned operandIndex = use.getOperandNumber();
|
|
||||||
|
|
||||||
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
|
|
||||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
|
||||||
llvm::SmallPtrSet<mlir::Value, 8> visited;
|
|
||||||
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
|
|
||||||
if (!visited.insert(currentValue).second)
|
|
||||||
return true;
|
|
||||||
if (currentValue.use_empty())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
|
|
||||||
if (isSpatialMvmVmmWeightUse(use))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
mlir::Operation* user = use.getOwner();
|
|
||||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
|
|
||||||
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
|
||||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
|
|
||||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
|
||||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
|
||||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
|
||||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
|
||||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
|
||||||
|
|
||||||
return false;
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
return walkUses(value, walkUses);
|
|
||||||
}
|
|
||||||
|
|
||||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
|
||||||
assert(root && "expected valid root op");
|
|
||||||
root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback); });
|
|
||||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
|
||||||
auto weights = coreBatchOp.getWeights();
|
|
||||||
for (auto weight : weights)
|
|
||||||
for (mlir::OpOperand& use : weight.getUses())
|
|
||||||
if (use.getOwner() == coreBatchOp.getOperation())
|
|
||||||
callback(use);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
bool hasWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
|
||||||
/// weight across later PIM lowering/codegen stages.
|
|
||||||
void markWeightAlways(mlir::Operation* op);
|
|
||||||
|
|
||||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
|
||||||
|
|
||||||
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
|
|
||||||
/// allowing later passes to preserve it as a dedicated weight-like object.
|
|
||||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
|
||||||
|
|
||||||
/// Visits weight operands consumed by Pim core ops/core batches so downstream
|
|
||||||
/// passes can identify globals that must remain weight-backed.
|
|
||||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -0,0 +1,546 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <filesystem>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
std::string getOutputDir() {
|
||||||
|
if (outputBaseName.empty() || outputBaseName == "-")
|
||||||
|
return {};
|
||||||
|
|
||||||
|
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||||
|
if (lastSlash == std::string::npos)
|
||||||
|
return ".";
|
||||||
|
return outputBaseName.substr(0, lastSlash);
|
||||||
|
}
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory) {
|
||||||
|
std::error_code errorCode;
|
||||||
|
std::filesystem::create_directories(directory, errorCode);
|
||||||
|
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||||
|
}
|
||||||
|
|
||||||
|
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string dialectsDir = outputDir + "/dialects";
|
||||||
|
createDirectory(dialectsDir);
|
||||||
|
|
||||||
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
|
llvm::raw_os_ostream os(file);
|
||||||
|
os << *moduleOp;
|
||||||
|
os.flush();
|
||||||
|
file.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||||
|
if (entryPoints.size() > 1) {
|
||||||
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!entryPoints.empty()) {
|
||||||
|
auto entryPointAttr =
|
||||||
|
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||||
|
if (!entryPointAttr) {
|
||||||
|
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||||
|
if (!entryFunc) {
|
||||||
|
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||||
|
<< entryPointAttr.getLeafReference().getValue();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return entryFunc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
||||||
|
return mainGraphFunc;
|
||||||
|
|
||||||
|
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||||
|
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||||
|
if (!funcOp.isExternal())
|
||||||
|
nonExternalFuncs.push_back(funcOp);
|
||||||
|
if (nonExternalFuncs.size() == 1)
|
||||||
|
return nonExternalFuncs.front();
|
||||||
|
|
||||||
|
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||||
|
|
||||||
|
void markWeightAlways(Operation* op) {
|
||||||
|
assert(op && "expected valid op");
|
||||||
|
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||||
|
}
|
||||||
|
|
||||||
|
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||||
|
if (!moduleOp || !getGlobalOp)
|
||||||
|
return {};
|
||||||
|
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||||
|
|
||||||
|
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
if (!channelNewOp) {
|
||||||
|
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// channelNewOp should have two users: `op` and a
|
||||||
|
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||||
|
auto channelUsers = channelNewOp->getUsers();
|
||||||
|
auto usersIterator = channelUsers.begin();
|
||||||
|
auto firstUser = *usersIterator;
|
||||||
|
usersIterator++;
|
||||||
|
if (usersIterator == channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"only one found.");
|
||||||
|
channelNewOp->dump();
|
||||||
|
op->dump();
|
||||||
|
channelNewOp->getParentOp()->dump();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto secondUser = *usersIterator;
|
||||||
|
usersIterator++;
|
||||||
|
if (usersIterator != channelUsers.end()) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"more than two found.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Operation* notOpUser;
|
||||||
|
if (firstUser == op) {
|
||||||
|
notOpUser = secondUser;
|
||||||
|
}
|
||||||
|
else if (secondUser == op) {
|
||||||
|
notOpUser = firstUser;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||||
|
"and one of them must be me, but"
|
||||||
|
"none of them is actually me.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (opIsReceive) {
|
||||||
|
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||||
|
"me, the other is not a ChannelSendOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return notOpUser;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||||
|
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||||
|
"me, the other is not a ChannelReceiveOp.");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return notOpUser;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
||||||
|
SmallVector<int64_t> indices(shape.size(), 0);
|
||||||
|
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
indices[dim] = linearIndex / stride;
|
||||||
|
linearIndex %= stride;
|
||||||
|
}
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
||||||
|
int64_t linearIndex = 0;
|
||||||
|
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||||
|
linearIndex += index * stride;
|
||||||
|
return linearIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
||||||
|
int64_t numElements = 1;
|
||||||
|
for (int64_t dim : shape)
|
||||||
|
numElements *= dim;
|
||||||
|
return numElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||||
|
ArrayRef<int64_t> offsets,
|
||||||
|
ArrayRef<int64_t> sizes,
|
||||||
|
ArrayRef<int64_t> strides) {
|
||||||
|
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||||
|
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstNonZeroOffset = std::find_if(
|
||||||
|
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return offset != 0;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||||
|
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||||
|
if (size > dimension - offset)
|
||||||
|
return false;
|
||||||
|
++firstNonZeroOffset;
|
||||||
|
|
||||||
|
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||||
|
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||||
|
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||||
|
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (firstDifferentSize != sizesAndShape.end()) {
|
||||||
|
++firstDifferentSize;
|
||||||
|
|
||||||
|
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size != 1;
|
||||||
|
}))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (!knowledge)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto iter = knowledge->aliases.find(value);
|
||||||
|
while (iter != knowledge->aliases.end()) {
|
||||||
|
value = iter->second;
|
||||||
|
iter = knowledge->aliases.find(value);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
||||||
|
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
||||||
|
// and when propagating yielded values across iterations during static unrolling.
|
||||||
|
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
if (auto result = dyn_cast<OpResult>(value))
|
||||||
|
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||||
|
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
||||||
|
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||||
|
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||||
|
|
||||||
|
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
if (knowledge) {
|
||||||
|
auto iter = knowledge->indexValues.find(value);
|
||||||
|
if (iter != knowledge->indexValues.end())
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
||||||
|
if (constantOp) {
|
||||||
|
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
|
||||||
|
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
||||||
|
|
||||||
|
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return failure();
|
||||||
|
return *lhs + *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return failure();
|
||||||
|
return *lhs - *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return failure();
|
||||||
|
return *lhs * *rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||||
|
return failure();
|
||||||
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
||||||
|
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
||||||
|
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
|
if (!integerAttr)
|
||||||
|
return failure();
|
||||||
|
return integerAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
||||||
|
const StaticValueKnowledge* knowledge) {
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
value = resolveAlias(value, knowledge);
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (isa<BlockArgument>(value))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||||
|
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||||
|
if (!tiedOperand)
|
||||||
|
return failure();
|
||||||
|
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
||||||
|
auto result = dyn_cast<OpResult>(value);
|
||||||
|
if (!result)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Trace the loop carry back to its underlying memref, then if that memref is the
|
||||||
|
// loop's own iter-arg we know the base comes from the corresponding init arg
|
||||||
|
// (every iteration yields the same backing memory in the DPS sense).
|
||||||
|
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||||
|
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||||
|
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||||
|
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||||
|
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||||
|
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
value = yieldedValue;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||||
|
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||||
|
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> offsets;
|
||||||
|
SmallVector<int64_t> sizes;
|
||||||
|
SmallVector<int64_t> strides;
|
||||||
|
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||||
|
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||||
|
strides.reserve(subviewOp.getMixedStrides().size());
|
||||||
|
|
||||||
|
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||||
|
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||||
|
if (failed(resolvedOffset))
|
||||||
|
return failure();
|
||||||
|
offsets.push_back(*resolvedOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||||
|
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||||
|
if (failed(resolvedSize))
|
||||||
|
return failure();
|
||||||
|
sizes.push_back(*resolvedSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||||
|
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||||
|
if (failed(resolvedStride))
|
||||||
|
return failure();
|
||||||
|
strides.push_back(*resolvedStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||||
|
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||||
|
value = resolveAlias(castOp.getSource(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||||
|
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||||
|
return ResolvedContiguousAddress {value, byteOffset};
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||||
|
|
||||||
|
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveIndexValueImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||||
|
return resolveContiguousAddressImpl(value, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveContiguousAddressImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isCoreStaticAddressOp(Operation* op) {
|
||||||
|
return isa<arith::ConstantOp,
|
||||||
|
arith::AddIOp,
|
||||||
|
arith::SubIOp,
|
||||||
|
arith::MulIOp,
|
||||||
|
arith::DivUIOp,
|
||||||
|
arith::RemUIOp,
|
||||||
|
arith::IndexCastOp,
|
||||||
|
memref::AllocOp,
|
||||||
|
memref::SubViewOp,
|
||||||
|
memref::CastOp,
|
||||||
|
memref::CollapseShapeOp,
|
||||||
|
memref::ExpandShapeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult walkPimCoreBlock(Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
|
||||||
|
bool hasFailure = false;
|
||||||
|
for (Operation& op : block) {
|
||||||
|
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||||
|
Block& loopBody = forOp.getRegion().front();
|
||||||
|
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||||
|
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||||
|
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||||
|
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
||||||
|
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
||||||
|
hasFailure = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
||||||
|
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
||||||
|
StaticValueKnowledge loopKnowledge = knowledge;
|
||||||
|
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||||
|
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||||
|
loopKnowledge.aliases[iterArg] = iterValue;
|
||||||
|
|
||||||
|
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
||||||
|
hasFailure = true;
|
||||||
|
|
||||||
|
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
|
||||||
|
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
||||||
|
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (failed(callback(op, knowledge)))
|
||||||
|
hasFailure = true;
|
||||||
|
}
|
||||||
|
return success(!hasFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -7,21 +7,82 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id";
|
struct ResolvedContiguousAddress {
|
||||||
|
mlir::Value base;
|
||||||
|
int64_t byteOffset = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StaticValueKnowledge {
|
||||||
|
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||||
|
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||||
|
|
||||||
|
StaticValueKnowledge() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string getOutputDir();
|
||||||
|
|
||||||
|
void createDirectory(const std::string& directory);
|
||||||
|
|
||||||
|
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||||
|
|
||||||
|
bool hasWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
|
llvm::FailureOr<mlir::Operation*>
|
||||||
|
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t>
|
||||||
|
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||||
|
|
||||||
|
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||||
|
llvm::ArrayRef<int64_t> offsets,
|
||||||
|
llvm::ArrayRef<int64_t> sizes,
|
||||||
|
llvm::ArrayRef<int64_t> strides);
|
||||||
|
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||||
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||||
|
const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||||
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
||||||
|
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
||||||
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||||
|
|
||||||
|
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
||||||
|
/// only contribute to static addressing or index computations (arith integer math,
|
||||||
|
/// memref view ops, memref.alloc, arith.constant).
|
||||||
|
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
||||||
|
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
||||||
|
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
||||||
|
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
||||||
|
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
||||||
|
mlir::LogicalResult
|
||||||
|
walkPimCoreBlock(mlir::Block& block,
|
||||||
|
const StaticValueKnowledge& knowledge,
|
||||||
|
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
#include "llvm/Support/raw_os_ostream.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
|
||||||
std::string outputDir = getOutputDir();
|
|
||||||
if (outputDir.empty())
|
|
||||||
return;
|
|
||||||
|
|
||||||
std::string dialectsDir = outputDir + "/dialects";
|
|
||||||
createDirectory(dialectsDir);
|
|
||||||
|
|
||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
|
||||||
llvm::raw_os_ostream os(file);
|
|
||||||
mlir::OpPrintingFlags flags;
|
|
||||||
flags.elideLargeElementsAttrs();
|
|
||||||
moduleOp.print(os, flags);
|
|
||||||
os.flush();
|
|
||||||
file.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Emits a MLIR snapshot under the current compiler output
|
|
||||||
/// directory for pass-level debugging.
|
|
||||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
|
||||||
|
|
||||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
|
|
||||||
return op->emitOpError() << "requires statically shaped " << valueDescription;
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
|
||||||
llvm::StringRef valueDescription,
|
|
||||||
int64_t actualRank,
|
|
||||||
llvm::ArrayRef<int64_t> supportedRanks) {
|
|
||||||
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
|
|
||||||
if (supportedRanks.empty())
|
|
||||||
return diag;
|
|
||||||
|
|
||||||
diag << "; supported rank";
|
|
||||||
if (supportedRanks.size() != 1)
|
|
||||||
diag << 's';
|
|
||||||
diag << ' ';
|
|
||||||
|
|
||||||
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
|
|
||||||
return diag;
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::InFlightDiagnostic
|
|
||||||
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
|
|
||||||
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
|
|
||||||
llvm::StringRef action,
|
|
||||||
llvm::StringRef path,
|
|
||||||
const std::error_code& errorCode) {
|
|
||||||
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
|
|
||||||
return mlir::failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir::pim
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/Diagnostics.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include <system_error>
|
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
|
||||||
|
|
||||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
|
||||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
|
||||||
|
|
||||||
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
|
|
||||||
/// accepted by the current lowering/codegen path.
|
|
||||||
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
|
||||||
llvm::StringRef valueDescription,
|
|
||||||
int64_t actualRank,
|
|
||||||
llvm::ArrayRef<int64_t> supportedRanks);
|
|
||||||
|
|
||||||
/// Emits a consistent diagnostic for missing symbol/global references.
|
|
||||||
mlir::InFlightDiagnostic
|
|
||||||
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
|
|
||||||
|
|
||||||
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
|
|
||||||
/// the relevant IR location.
|
|
||||||
mlir::LogicalResult
|
|
||||||
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
|
|
||||||
return mlir::success(succeeded(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir::pim
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
#include <filesystem>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
std::string getOutputDir() {
|
|
||||||
if (outputBaseName.empty() || outputBaseName == "-")
|
|
||||||
return {};
|
|
||||||
|
|
||||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
|
||||||
if (lastSlash == std::string::npos)
|
|
||||||
return ".";
|
|
||||||
return outputBaseName.substr(0, lastSlash);
|
|
||||||
}
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory) {
|
|
||||||
std::error_code errorCode;
|
|
||||||
std::filesystem::create_directories(directory, errorCode);
|
|
||||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
/// Returns the directory that should hold PIM artifacts/debug dumps for the
|
|
||||||
/// current compiler invocation.
|
|
||||||
std::string getOutputDir();
|
|
||||||
|
|
||||||
void createDirectory(const std::string& directory);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
+123
-324
@@ -1,15 +1,11 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/StringExtras.h"
|
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -57,23 +53,9 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
SmallVector<mlir::Value> args;
|
|
||||||
|
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments()){
|
|
||||||
gatherMemEntry(arg);
|
|
||||||
args.push_back(arg);
|
|
||||||
}
|
|
||||||
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (globalMemrefOp.getName().starts_with("arg")){
|
|
||||||
StringRef indexStr = globalMemrefOp.getName().substr(4);
|
|
||||||
int index = 0;
|
|
||||||
llvm::to_integer(indexStr,index, 10);
|
|
||||||
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
|
|
||||||
}
|
|
||||||
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (inserted)
|
if (inserted)
|
||||||
gatherMemEntry(getGlobalOp.getResult());
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
@@ -82,6 +64,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
for (mlir::Value arg : funcOp.getArguments())
|
||||||
|
gatherMemEntry(arg);
|
||||||
|
|
||||||
funcOp.walk([&](memref::AllocOp allocOp) {
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
@@ -147,12 +131,6 @@ json::Object PimCodeGen::createEmptyOffset() {
|
|||||||
return offset;
|
return offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t PimCodeGen::remapCoreId(size_t coreId) const {
|
|
||||||
auto it = emittedCoreIds.find(coreId);
|
|
||||||
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
static json::Object createRs1OnlyOffset() {
|
static json::Object createRs1OnlyOffset() {
|
||||||
json::Object offset;
|
json::Object offset;
|
||||||
offset["offset_select"] = 1;
|
offset["offset_select"] = 1;
|
||||||
@@ -212,7 +190,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
|
|||||||
json::Object json;
|
json::Object json;
|
||||||
json["op"] = opName;
|
json["op"] = opName;
|
||||||
json["rd"] = 0;
|
json["rd"] = 0;
|
||||||
json["core"] = remapCoreId(coreId);
|
json["core"] = coreId;
|
||||||
json["size"] = size;
|
json["size"] = size;
|
||||||
json["offset"] = createEmptyOffset();
|
json["offset"] = createEmptyOffset();
|
||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
@@ -434,9 +412,6 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
|||||||
emitInstruction(std::move(json));
|
emitInstruction(std::move(json));
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
|
|
||||||
}
|
|
||||||
|
|
||||||
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
|
||||||
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
|
||||||
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
|
||||||
@@ -499,136 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
return std::to_string(size) + " Bytes";
|
return std::to_string(size) + " Bytes";
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
SmallVector<unsigned, 8> indices;
|
SmallVector<unsigned, 8> indices;
|
||||||
auto addIndex = [&](unsigned weightIndex) {
|
auto addIndex = [&](unsigned weightIndex) {
|
||||||
if (!llvm::is_contained(indices, weightIndex))
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
indices.push_back(weightIndex);
|
indices.push_back(weightIndex);
|
||||||
};
|
};
|
||||||
|
|
||||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
llvm::sort(indices);
|
llvm::sort(indices);
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
|
||||||
return getUsedWeightIndices(coreOp.getBody().front());
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
|
||||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdAttrName);
|
|
||||||
assert(coreIdsAttr && "pim.core_batch requires core_id array attribute");
|
|
||||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
|
||||||
SmallVector<Operation*> coreLikeOps;
|
|
||||||
for (Operation& op : funcOp.getBody().front()) {
|
|
||||||
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
|
||||||
coreLikeOps.push_back(&op);
|
|
||||||
}
|
|
||||||
return coreLikeOps;
|
|
||||||
}
|
|
||||||
|
|
||||||
static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) {
|
|
||||||
OpBuilder builder(coreBatchOp);
|
|
||||||
builder.setInsertionPointAfter(coreBatchOp);
|
|
||||||
|
|
||||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
|
||||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
|
||||||
SmallVector<mlir::Value> laneWeights;
|
|
||||||
laneWeights.reserve(weightsPerLane);
|
|
||||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
|
||||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
|
||||||
|
|
||||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
|
||||||
auto scalarCore = pim::PimCoreOp::create(builder,
|
|
||||||
coreBatchOp.getLoc(),
|
|
||||||
ValueRange(laneWeights),
|
|
||||||
builder.getI32IntegerAttr(coreIds[lane]));
|
|
||||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
|
||||||
IRMapping mapper;
|
|
||||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
|
||||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
|
||||||
|
|
||||||
builder.setInsertionPointToEnd(block);
|
|
||||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
|
||||||
if (isa<pim::PimHaltOp>(op)) {
|
|
||||||
pim::PimHaltOp::create(builder, op.getLoc());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
|
||||||
pim::PimSendOp::create(builder,
|
|
||||||
sendBatchOp.getLoc(),
|
|
||||||
mapper.lookup(sendBatchOp.getInput()),
|
|
||||||
sendBatchOp.getSizeAttr(),
|
|
||||||
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
|
||||||
auto scalarReceive = pim::PimReceiveOp::create(builder,
|
|
||||||
receiveBatchOp.getLoc(),
|
|
||||||
receiveBatchOp.getOutput().getType(),
|
|
||||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
|
||||||
receiveBatchOp.getSizeAttr(),
|
|
||||||
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
|
||||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
|
||||||
mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
|
||||||
if (!hostSource)
|
|
||||||
hostSource = memcpBatchOp.getHostSource();
|
|
||||||
|
|
||||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
|
|
||||||
memcpBatchOp.getLoc(),
|
|
||||||
memcpBatchOp.getOutput().getType(),
|
|
||||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
|
||||||
hostSource,
|
|
||||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
|
||||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
|
||||||
memcpBatchOp.getSizeAttr());
|
|
||||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
Operation* cloned = builder.clone(op, mapper);
|
|
||||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
|
||||||
mapper.map(originalResult, clonedResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
|
||||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
|
||||||
return scalarCore;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void aliasMaterializedHostGlobals(
|
|
||||||
ModuleOp moduleOp, func::FuncOp funcOp, pim::PimCoreOp coreOp, PimAcceleratorMemory& memory) {
|
|
||||||
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
|
||||||
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
|
|
||||||
return;
|
|
||||||
|
|
||||||
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
|
||||||
if (!targetGlobal)
|
|
||||||
return;
|
|
||||||
|
|
||||||
mlir::Value aliasedValue;
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp candidate) {
|
|
||||||
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
|
|
||||||
return;
|
|
||||||
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
|
|
||||||
aliasedValue = candidate.getResult();
|
|
||||||
});
|
|
||||||
|
|
||||||
if (aliasedValue)
|
|
||||||
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
@@ -723,8 +581,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
|||||||
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
|
||||||
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
|
||||||
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
|
||||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
|
||||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
|
||||||
else {
|
else {
|
||||||
op.emitError("Unsupported codegen for this operation");
|
op.emitError("Unsupported codegen for this operation");
|
||||||
op.dump();
|
op.dump();
|
||||||
@@ -814,7 +670,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
|
|||||||
return CompilerSuccess;
|
return CompilerSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
|
||||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||||
@@ -823,104 +679,85 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
size_t indexFileName = 0;
|
size_t indexFileName = 0;
|
||||||
|
|
||||||
int64_t xbarSize = crossbarSize.getValue();
|
int64_t xbarSize = crossbarSize.getValue();
|
||||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
|
||||||
for (Operation* op : coreLikeOps) {
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
SmallVector<pim::PimCoreOp> scalarCores;
|
if (!getGlobalOp) {
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||||
scalarCores.push_back(coreOp);
|
assert(!getGlobalOp && "Weight is not from a memref.get_global");
|
||||||
}
|
}
|
||||||
else {
|
|
||||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
|
||||||
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : scalarCores) {
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
if (!globalOp) {
|
||||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
|
||||||
if (index >= coreOp.getWeights().size()) {
|
assert(!globalOp && "Could not find memref.global");
|
||||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
}
|
||||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
|
||||||
}
|
|
||||||
mlir::Value weight = coreOp.getWeights()[index];
|
|
||||||
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto initialValue = globalOp.getInitialValue();
|
||||||
if (!getGlobalOp) {
|
if (!initialValue) {
|
||||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
|
||||||
assert(!getGlobalOp && "Weight is not from a memref.get_global");
|
assert(!initialValue && "memref.global has no initial value");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||||
if (!globalOp) {
|
if (!denseAttr) {
|
||||||
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
|
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
|
||||||
assert(!globalOp && "Could not find memref.global");
|
assert(!denseAttr && "memref.global initial value is not dense");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto initialValue = globalOp.getInitialValue();
|
if (mapGlobalOpToFileName.contains(globalOp)) {
|
||||||
if (!initialValue) {
|
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||||
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
|
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
|
||||||
assert(!initialValue && "memref.global has no initial value");
|
mapCoreWeightToFileName[coreOp].insert(weightToFile);
|
||||||
}
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
auto type = denseAttr.getType();
|
||||||
if (!denseAttr) {
|
auto shape = type.getShape();
|
||||||
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
|
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||||
assert(!denseAttr && "memref.global initial value is not dense");
|
int64_t numRows = shape[0];
|
||||||
}
|
int64_t numCols = shape[1];
|
||||||
|
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||||
|
|
||||||
if (mapGlobalOpToFileName.contains(globalOp)) {
|
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
|
||||||
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
|
|
||||||
mapCoreWeightToFileName[coreId].insert(weightToFile);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto type = denseAttr.getType();
|
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||||
auto shape = type.getShape();
|
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
std::error_code errorCode;
|
||||||
int64_t numRows = shape[0];
|
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||||
int64_t numCols = shape[1];
|
if (errorCode) {
|
||||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||||
|
assert(errorCode);
|
||||||
|
}
|
||||||
|
|
||||||
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
|
uint64_t zero = 0;
|
||||||
|
for (int64_t row = 0; row < xbarSize; row++) {
|
||||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
for (int64_t col = 0; col < xbarSize; col++) {
|
||||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
if (row < numRows && col < numCols) {
|
||||||
std::error_code errorCode;
|
int64_t index = row * numCols + col;
|
||||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
|
||||||
if (errorCode) {
|
uint64_t word = bits.getZExtValue();
|
||||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||||
assert(errorCode);
|
}
|
||||||
}
|
else {
|
||||||
|
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||||
uint64_t zero = 0;
|
|
||||||
for (int64_t row = 0; row < xbarSize; row++) {
|
|
||||||
for (int64_t col = 0; col < xbarSize; col++) {
|
|
||||||
if (row < numRows && col < numCols) {
|
|
||||||
int64_t index = row * numCols + col;
|
|
||||||
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
|
|
||||||
uint64_t word = bits.getZExtValue();
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
weightFileStream.close();
|
|
||||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
|
||||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : scalarCores)
|
weightFileStream.close();
|
||||||
if (coreOp.getOperation() != op)
|
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||||
coreOp.erase();
|
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return mapCoreWeightToFileName;
|
return mapCoreWeightToFileName;
|
||||||
}
|
}
|
||||||
@@ -928,14 +765,13 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
|
||||||
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||||
PimAcceleratorMemory& memory,
|
PimAcceleratorMemory& memory,
|
||||||
size_t maxCoreId,
|
size_t coreCount,
|
||||||
json::Object xbarsPerArrayGroup,
|
json::Object xbarsPerArrayGroup,
|
||||||
StringRef outputDirPath) {
|
StringRef outputDirPath) {
|
||||||
json::Object configJson;
|
json::Object configJson;
|
||||||
|
|
||||||
// pimsim-nn indexes cores directly by their numeric core ID, with the host
|
// +1 because pimsim-nn also considers the host as a core
|
||||||
// occupying core 0.
|
configJson["core_cnt"] = coreCount + 1;
|
||||||
configJson["core_cnt"] = maxCoreId + 1;
|
|
||||||
|
|
||||||
// TODO: Should this be based on the floating point type used in the model?
|
// TODO: Should this be based on the floating point type used in the model?
|
||||||
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
|
||||||
@@ -1009,103 +845,66 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
// For each core, specify the number of crossbar per array group.
|
// For each core, specify the number of crossbar per array group.
|
||||||
// This implementation always assigns one crossbar per group.
|
// This implementation always assigns one crossbar per group.
|
||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t maxCoreId = 0;
|
size_t coreCount = 0;
|
||||||
|
|
||||||
// Create Weight Folder
|
// Create Weight Folder
|
||||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||||
|
|
||||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
auto coreId = coreOp.getCoreId();
|
||||||
size_t nextEmittedCoreId = 1;
|
coreCount++;
|
||||||
|
|
||||||
for (Operation* op : coreLikeOps) {
|
std::error_code errorCode;
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||||
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
|
||||||
if (!emittedCoreIds.contains(originalCoreId))
|
if (errorCode) {
|
||||||
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
|
||||||
continue;
|
return InvalidOutputFileAccess;
|
||||||
|
}
|
||||||
|
coreFileStream << '[';
|
||||||
|
|
||||||
|
PimCodeGen coreCodeGen(memory, coreFileStream);
|
||||||
|
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||||
|
|
||||||
|
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||||
|
if (processedOperations < 0)
|
||||||
|
return CompilerFailure;
|
||||||
|
assert(processedOperations > 0);
|
||||||
|
|
||||||
|
// Remove trailing comma, close JSON array
|
||||||
|
coreFileStream.seek(coreFileStream.tell() - 1);
|
||||||
|
coreFileStream << ']';
|
||||||
|
coreFileStream.close();
|
||||||
|
|
||||||
|
// Write crossbar weights for this core
|
||||||
|
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||||
|
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||||
|
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
||||||
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
||||||
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
json::Array xbarsPerGroup;
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
if (index >= coreOp.getWeights().size()) {
|
||||||
if (!emittedCoreIds.contains(originalCoreId))
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
}
|
}
|
||||||
}
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
xbarsPerGroup.push_back(index);
|
||||||
for (Operation* op : coreLikeOps) {
|
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||||
SmallVector<pim::PimCoreOp> scalarCores;
|
auto& fileName = mapWeightToFile[weight];
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
||||||
scalarCores.push_back(coreOp);
|
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
||||||
}
|
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
||||||
else {
|
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
|
||||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
<< '\n';
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
|
||||||
scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : scalarCores) {
|
|
||||||
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
|
||||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
|
||||||
maxCoreId = std::max(maxCoreId, coreId);
|
|
||||||
|
|
||||||
std::error_code errorCode;
|
|
||||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
|
||||||
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
return InvalidOutputFileAccess;
|
||||||
}
|
}
|
||||||
coreFileStream << '[';
|
|
||||||
|
|
||||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
|
||||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
|
||||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
|
||||||
|
|
||||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
|
||||||
if (processedOperations < 0)
|
|
||||||
return CompilerFailure;
|
|
||||||
assert(processedOperations > 0);
|
|
||||||
|
|
||||||
coreFileStream.seek(coreFileStream.tell() - 1);
|
|
||||||
coreFileStream << ']';
|
|
||||||
coreFileStream.close();
|
|
||||||
|
|
||||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
|
||||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
|
||||||
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
|
|
||||||
json::Array xbarsPerGroup;
|
|
||||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
|
||||||
if (index >= coreOp.getWeights().size()) {
|
|
||||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
|
||||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
|
||||||
}
|
|
||||||
mlir::Value weight = coreOp.getWeights()[index];
|
|
||||||
xbarsPerGroup.push_back(index);
|
|
||||||
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
|
||||||
auto& fileName = mapWeightToFile[weight];
|
|
||||||
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
|
|
||||||
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
|
|
||||||
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
|
|
||||||
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:"
|
|
||||||
<< error.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : scalarCores)
|
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
|
||||||
if (coreOp.getOperation() != op)
|
|
||||||
coreOp.erase();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
|
return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
|
|
||||||
@@ -59,12 +58,10 @@ public:
|
|||||||
class PimCodeGen {
|
class PimCodeGen {
|
||||||
PimAcceleratorMemory& memory;
|
PimAcceleratorMemory& memory;
|
||||||
llvm::raw_fd_ostream& coreFileStream;
|
llvm::raw_fd_ostream& coreFileStream;
|
||||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge);
|
return memory.getValueAddress(value, knowledge);
|
||||||
}
|
}
|
||||||
size_t remapCoreId(size_t coreId) const;
|
|
||||||
|
|
||||||
static llvm::json::Object createEmptyOffset();
|
static llvm::json::Object createEmptyOffset();
|
||||||
void emitInstruction(llvm::json::Object instruction) const;
|
void emitInstruction(llvm::json::Object instruction) const;
|
||||||
@@ -86,10 +83,8 @@ class PimCodeGen {
|
|||||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PimCodeGen(PimAcceleratorMemory& memory,
|
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||||
llvm::raw_fd_ostream& coreJson,
|
: memory(memory), coreFileStream(coreJson) {}
|
||||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
|
||||||
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
|
||||||
|
|
||||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
@@ -111,7 +106,6 @@ public:
|
|||||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||||
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
|
|
||||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ llvm::cl::opt<size_t>
|
|||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
|
||||||
|
|
||||||
llvm::cl::opt<long> coresCount("core-count",
|
llvm::cl::opt<long> coresCount("core-count",
|
||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||||
@@ -51,7 +51,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
|||||||
"dcp-critical-window-size",
|
"dcp-critical-window-size",
|
||||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||||
llvm::cl::init(4000));
|
llvm::cl::init(1024));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
ignoreConcatError("ignore-concat-error",
|
ignoreConcatError("ignore-concat-error",
|
||||||
|
|||||||
@@ -12,7 +12,6 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -175,31 +174,6 @@ using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
|||||||
|
|
||||||
} // namespace detail
|
} // namespace detail
|
||||||
|
|
||||||
template <typename RewriterT>
|
|
||||||
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
|
||||||
assert(!inputs.empty() && "spat.concat requires at least one input");
|
|
||||||
if (inputs.size() == 1)
|
|
||||||
return inputs.front();
|
|
||||||
|
|
||||||
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
|
||||||
auto outputShape = llvm::to_vector(firstType.getShape());
|
|
||||||
int64_t concatDimSize = 0;
|
|
||||||
bool concatDimDynamic = false;
|
|
||||||
|
|
||||||
for (mlir::Value input : inputs) {
|
|
||||||
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
|
||||||
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
|
||||||
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
|
||||||
concatDimDynamic = true;
|
|
||||||
else
|
|
||||||
concatDimSize += inputType.getDimSize(axis);
|
|
||||||
}
|
|
||||||
|
|
||||||
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
|
||||||
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
||||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||||
auto createSpatCompute(RewriterT& rewriter,
|
auto createSpatCompute(RewriterT& rewriter,
|
||||||
mlir::Location loc,
|
mlir::Location loc,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
@@ -9,10 +8,10 @@
|
|||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
@@ -25,6 +24,8 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -51,48 +52,12 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
|
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
|
||||||
SmallVector<spatial::SpatComputeBatch> batchOps;
|
|
||||||
funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); });
|
|
||||||
|
|
||||||
for (auto batchOp : batchOps) {
|
|
||||||
if (batchOp.getLaneCount() != 1)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
auto loc = batchOp.getLoc();
|
|
||||||
rewriter.setInsertionPoint(batchOp);
|
|
||||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
|
||||||
computeOp.getProperties().setOperandSegmentSizes(
|
|
||||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
|
||||||
|
|
||||||
Block& templateBlock = batchOp.getBody().front();
|
|
||||||
SmallVector<Type> blockArgTypes;
|
|
||||||
SmallVector<Location> blockArgLocs;
|
|
||||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
|
||||||
blockArgTypes.push_back(arg.getType());
|
|
||||||
blockArgLocs.push_back(loc);
|
|
||||||
}
|
|
||||||
auto* newBlock = rewriter.createBlock(
|
|
||||||
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
||||||
|
|
||||||
IRMapping mapper;
|
|
||||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
|
||||||
mapper.map(oldArg, newArg);
|
|
||||||
rewriter.setInsertionPointToEnd(newBlock);
|
|
||||||
for (Operation& op : templateBlock)
|
|
||||||
rewriter.clone(op, mapper);
|
|
||||||
|
|
||||||
batchOp.replaceAllUsesWith(computeOp.getResults());
|
|
||||||
rewriter.eraseOp(batchOp);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
@@ -122,7 +87,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addIllegalOp<ONNXMatMulOp>();
|
target.addDynamicallyLegalOp<ONNXMatMulOp>(
|
||||||
|
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
|
||||||
target.addIllegalOp<ONNXAddOp>();
|
target.addIllegalOp<ONNXAddOp>();
|
||||||
target.addIllegalOp<ONNXDivOp>();
|
target.addIllegalOp<ONNXDivOp>();
|
||||||
target.addIllegalOp<ONNXMulOp>();
|
target.addIllegalOp<ONNXMulOp>();
|
||||||
@@ -163,8 +129,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
foldSingleLaneComputeBatches(*entryFunc);
|
|
||||||
|
|
||||||
// Count the number of compute ops and check they do not exceed the core count
|
// Count the number of compute ops and check they do not exceed the core count
|
||||||
if (coresCount != -1) {
|
if (coresCount != -1) {
|
||||||
int computeOpsCount = 0;
|
int computeOpsCount = 0;
|
||||||
@@ -185,7 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||||
@@ -193,6 +156,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mergeTriviallyConnectedComputes(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
}
|
}
|
||||||
@@ -202,36 +167,19 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
|||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||||
Value source = funcSource(toRemoveOp);
|
Value source = funcSource(toRemoveOp);
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
IRMapping mapper;
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
mapper.map(source, BB->getArgument(0));
|
IRMapping mapper;
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
mapper.map(source, BB->getArgument(0));
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
inst->erase();
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
return true;
|
inst->erase();
|
||||||
}
|
return true;
|
||||||
return false;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|
||||||
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
|
||||||
auto source = toRemoveOp.getSource();
|
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
mapper.map(source, BB->getArgument(0));
|
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -240,8 +188,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||||
auto sources = toRemoveOp.getInputs();
|
auto sources = toRemoveOp.getInputs();
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (llvm::any_of(sources,
|
if (llvm::any_of(
|
||||||
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||||
SmallVector<Type> sourceTypes;
|
SmallVector<Type> sourceTypes;
|
||||||
SmallVector<Location> sourceLoc;
|
SmallVector<Location> sourceLoc;
|
||||||
@@ -255,34 +203,12 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||||
mapper.map(source, bbArg);
|
mapper.map(source, bbArg);
|
||||||
auto newConcat = spatial::SpatConcatOp::create(rewriter,
|
auto newConcat = rewriter.clone(*inst, mapper);
|
||||||
loc,
|
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||||
toRemoveOp.getType(),
|
|
||||||
rewriter.getI64IntegerAttr(toRemoveOp.getDim()),
|
|
||||||
ValueRange(BB->getArguments()));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput());
|
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
inst->erase();
|
inst->erase();
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
|
||||||
SmallVector<Type> sourceTypes;
|
|
||||||
SmallVector<Location> sourceLoc;
|
|
||||||
for (auto source : sources) {
|
|
||||||
sourceTypes.push_back(source.getType());
|
|
||||||
sourceLoc.push_back(loc);
|
|
||||||
}
|
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
|
||||||
IRMapping mapper;
|
|
||||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
|
||||||
mapper.map(source, bbArg);
|
|
||||||
auto newConcat = rewriter.clone(*inst, mapper);
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
|
||||||
inst->erase();
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -344,89 +270,6 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
|
|||||||
return cast<Value>(mapped);
|
return cast<Value>(mapped);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool sourceOpernadHasWeightAlways(Operation* op) {
|
|
||||||
if (op == nullptr)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
Operation* source = nullptr;
|
|
||||||
do {
|
|
||||||
|
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
|
|
||||||
auto tmpSource = extractSliceOp.getSource();
|
|
||||||
auto definingOp = tmpSource.getDefiningOp();
|
|
||||||
if (definingOp)
|
|
||||||
op = definingOp;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
|
|
||||||
auto tmpSource = extractRowsOp.getInput();
|
|
||||||
auto definingOp = tmpSource.getDefiningOp();
|
|
||||||
if (definingOp)
|
|
||||||
op = definingOp;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
|
|
||||||
auto tmpSource = expandShapeOp.getSrc();
|
|
||||||
auto definingOp = tmpSource.getDefiningOp();
|
|
||||||
if (definingOp)
|
|
||||||
op = definingOp;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
|
|
||||||
auto tmpSource = transposeOp.getData();
|
|
||||||
auto definingOp = tmpSource.getDefiningOp();
|
|
||||||
if (definingOp)
|
|
||||||
op = definingOp;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
|
|
||||||
auto tmpSource = collapseShapeOp.getSrc();
|
|
||||||
auto definingOp = tmpSource.getDefiningOp();
|
|
||||||
if (definingOp)
|
|
||||||
op = definingOp;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
|
|
||||||
source = constantOp;
|
|
||||||
}
|
|
||||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
|
|
||||||
bool res = false;
|
|
||||||
for (auto operand : concatOp.getOperands()) {
|
|
||||||
res |= hasWeightAlways(operand.getDefiningOp());
|
|
||||||
if (res)
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
|
|
||||||
bool res = false;
|
|
||||||
for (auto operand : concatOp.getOperands()) {
|
|
||||||
res |= hasWeightAlways(operand.getDefiningOp());
|
|
||||||
if (res)
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
op->dump();
|
|
||||||
llvm_unreachable("Global instruction not handle in func");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
while (source == nullptr);
|
|
||||||
|
|
||||||
if (hasWeightAlways(source))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO what we want to keep in global?
|
// TODO what we want to keep in global?
|
||||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
@@ -435,14 +278,8 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
while (keep) {
|
while (keep) {
|
||||||
keep = false;
|
keep = false;
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||||
|
keep |= encapsulator<tensor::ExtractSliceOp>(
|
||||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
||||||
instruction)
|
|
||||||
|| isa<func::ReturnOp>(instruction)
|
|
||||||
|| sourceOpernadHasWeightAlways(&instruction))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||||
@@ -458,9 +295,105 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
IRRewriter rewriter(&getContext());
|
||||||
|
SmallVector<spatial::SpatCompute> trivialComputes;
|
||||||
|
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||||
|
if (compute->hasOneUse()) {
|
||||||
|
auto& use = *compute->getUses().begin();
|
||||||
|
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||||
|
|
||||||
|
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||||
|
trivialComputes.push_back(compute);
|
||||||
|
}
|
||||||
|
|
||||||
|
while (!trivialComputes.empty()) {
|
||||||
|
auto compute = trivialComputes.front();
|
||||||
|
|
||||||
|
if (compute.use_empty()) {
|
||||||
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto& computeUse = *compute->getUses().begin();
|
||||||
|
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
|
||||||
|
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||||
|
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
|
||||||
|
auto newCompute =
|
||||||
|
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
auto weightMutableIter = newCompute.getWeightsMutable();
|
||||||
|
for (auto weight : child.getWeights()) {
|
||||||
|
auto founded = llvm::find(newCompute.getWeights(), weight);
|
||||||
|
if (founded == newCompute.getWeights().end()) {
|
||||||
|
weightMutableIter.append(weight);
|
||||||
|
auto last = weightMutableIter.end();
|
||||||
|
last = std::prev(last, 1);
|
||||||
|
mapper.map(weight, last->get());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
mapper.map(weight, *founded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
|
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||||
|
newTerminator->erase();
|
||||||
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
|
for (auto& op : child.getBody().front()) {
|
||||||
|
auto newInst = rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
child.replaceAllUsesWith(newCompute);
|
||||||
|
toErase.insert(child);
|
||||||
|
|
||||||
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
toErase.insert(compute);
|
||||||
|
|
||||||
|
if (newCompute->hasOneUse()) {
|
||||||
|
auto& use = *newCompute->getUses().begin();
|
||||||
|
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||||
|
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||||
|
trivialComputes.push_back(newCompute);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : toErase) {
|
||||||
|
for (Value result : compute->getResults())
|
||||||
|
result.dropAllUses();
|
||||||
|
compute.erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
bool isAlwaysWeight =
|
||||||
|
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
||||||
|
if (isAlwaysWeight)
|
||||||
markWeightAlways(constantOp);
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -147,148 +147,162 @@ static Value buildPackedBias(bool hasBias,
|
|||||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createIm2colRowComputes(Value x,
|
static SmallVector<Value> createIm2colRowComputes(Value x,
|
||||||
RankedTensorType xType,
|
RankedTensorType xType,
|
||||||
RankedTensorType im2colType,
|
RankedTensorType im2colType,
|
||||||
RankedTensorType im2colRowType,
|
RankedTensorType im2colRowType,
|
||||||
RankedTensorType gemmInputRowsType,
|
RankedTensorType gemmInputRowType,
|
||||||
int64_t batchSize,
|
int64_t batchSize,
|
||||||
int64_t numChannelsIn,
|
int64_t numChannelsIn,
|
||||||
int64_t xHeight,
|
int64_t xHeight,
|
||||||
int64_t xWidth,
|
int64_t xWidth,
|
||||||
int64_t wHeight,
|
int64_t wHeight,
|
||||||
int64_t wWidth,
|
int64_t wWidth,
|
||||||
int64_t padHeightBegin,
|
int64_t padHeightBegin,
|
||||||
int64_t padHeightEnd,
|
int64_t padHeightEnd,
|
||||||
int64_t padWidthBegin,
|
int64_t padWidthBegin,
|
||||||
int64_t padWidthEnd,
|
int64_t padWidthEnd,
|
||||||
int64_t strideHeight,
|
int64_t strideHeight,
|
||||||
int64_t strideWidth,
|
int64_t strideWidth,
|
||||||
int64_t dilationHeight,
|
int64_t dilationHeight,
|
||||||
int64_t dilationWidth,
|
int64_t dilationWidth,
|
||||||
int64_t outWidth,
|
int64_t outWidth,
|
||||||
int64_t patchSize,
|
int64_t patchSize,
|
||||||
int64_t numPatches,
|
int64_t numPatches,
|
||||||
int64_t numPatchesPerBatch,
|
int64_t numPatchesPerBatch,
|
||||||
int64_t packFactor,
|
int64_t packFactor,
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
auto im2colComputeOp =
|
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
||||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
|
|
||||||
// Pad input with zeros if needed:
|
// Pad input with zeros if needed:
|
||||||
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
|
||||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||||
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(padHeightBegin),
|
rewriter.getIndexAttr(padHeightBegin),
|
||||||
rewriter.getIndexAttr(padWidthBegin)};
|
rewriter.getIndexAttr(padWidthBegin)};
|
||||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(padHeightEnd),
|
rewriter.getIndexAttr(padHeightEnd),
|
||||||
rewriter.getIndexAttr(padWidthEnd)};
|
rewriter.getIndexAttr(padWidthEnd)};
|
||||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
|
||||||
auto* padBlock = new Block();
|
auto* padBlock = new Block();
|
||||||
for (int i = 0; i < 4; i++)
|
for (int i = 0; i < 4; i++)
|
||||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
padOp.getRegion().push_back(padBlock);
|
padOp.getRegion().push_back(padBlock);
|
||||||
rewriter.setInsertionPointToStart(padBlock);
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||||
rewriter.setInsertionPointAfter(padOp);
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
paddedInput = padOp.getResult();
|
paddedInput = padOp.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||||
// until the late PIM unrolling step.
|
// until the late PIM unrolling step.
|
||||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||||
|
|
||||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||||
|
|
||||||
Value patchIndex = im2colLoop.getInductionVar();
|
Value patchIndex = im2colLoop.getInductionVar();
|
||||||
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
|
||||||
|
|
||||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
|
||||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
|
||||||
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||||
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||||
|
|
||||||
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(numChannelsIn),
|
rewriter.getIndexAttr(numChannelsIn),
|
||||||
rewriter.getIndexAttr(wHeight),
|
rewriter.getIndexAttr(wHeight),
|
||||||
rewriter.getIndexAttr(wWidth)};
|
rewriter.getIndexAttr(wWidth)};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(dilationHeight),
|
rewriter.getIndexAttr(dilationHeight),
|
||||||
rewriter.getIndexAttr(dilationWidth)};
|
rewriter.getIndexAttr(dilationWidth)};
|
||||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||||
|
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
im2colRowType,
|
im2colRowType,
|
||||||
patch,
|
patch,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
{1, 2, 3}
|
{1, 2, 3}
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
|
||||||
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Value updatedIm2col =
|
|
||||||
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
|
||||||
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(im2colLoop);
|
|
||||||
Value im2col = im2colLoop.getResult(0);
|
|
||||||
|
|
||||||
Value gemmInputRows = im2col;
|
|
||||||
if (packFactor != 1) {
|
|
||||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
|
||||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
|
||||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
|
||||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
|
||||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
groupedType,
|
|
||||||
paddedIm2col,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
packedType,
|
|
||||||
groupedIm2col,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
return im2colComputeOp.getResult(0);
|
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
|
||||||
|
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value updatedIm2col =
|
||||||
|
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
|
||||||
|
scf::YieldOp::create(rewriter, loc, updatedIm2col);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(im2colLoop);
|
||||||
|
Value im2col = im2colLoop.getResult(0);
|
||||||
|
|
||||||
|
Value gemmInputRows = im2col;
|
||||||
|
if (packFactor != 1) {
|
||||||
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||||
|
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||||
|
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||||
|
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
groupedType,
|
||||||
|
paddedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
packedType,
|
||||||
|
groupedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> rowResults;
|
||||||
|
rowResults.reserve(packedNumRows);
|
||||||
|
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
|
rewriter.getIndexAttr(packFactor * patchSize)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
rowResults.push_back(
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
|
||||||
|
});
|
||||||
|
|
||||||
|
SmallVector<Value> rows;
|
||||||
|
rows.reserve(im2colComputeOp.getNumResults());
|
||||||
|
for (Value result : im2colComputeOp.getResults())
|
||||||
|
rows.push_back(result);
|
||||||
|
return rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createCollectedConvOutput(ValueRange gemmRows,
|
static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||||
@@ -306,12 +320,16 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||||
Value gemmOut;
|
Value gemmOut;
|
||||||
if (packFactor == 1) {
|
if (packFactor == 1) {
|
||||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
Value packedOutput =
|
||||||
|
gemmRowArgs.size() == 1
|
||||||
|
? gemmRowArgs.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
|
||||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
expandedType,
|
expandedType,
|
||||||
@@ -487,42 +505,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
|
||||||
//
|
//
|
||||||
// We want to process N pixels at the same time. Instead of doing N separate operations
|
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
|
||||||
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
// the row it needs instead of receiving a full packed tensor and slicing it locally.
|
||||||
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
auto gemmInputRowType =
|
||||||
// A_packed: [ceil(numPatches / N), N * patchSize]
|
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
|
||||||
// B_packed: [N * patchSize, N * cOut]
|
auto gemmOutputRowType =
|
||||||
// Y_packed: [ceil(numPatches / N), N * cOut]
|
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
|
||||||
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
|
xType,
|
||||||
auto gemmOutputRowsType =
|
im2colType,
|
||||||
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
rowType,
|
||||||
Value gemmInputRows = createIm2colRowComputes(x,
|
gemmInputRowType,
|
||||||
xType,
|
batchSize,
|
||||||
im2colType,
|
numChannelsIn,
|
||||||
rowType,
|
xHeight,
|
||||||
gemmInputRowsType,
|
xWidth,
|
||||||
batchSize,
|
wHeight,
|
||||||
numChannelsIn,
|
wWidth,
|
||||||
xHeight,
|
padHeightBegin,
|
||||||
xWidth,
|
padHeightEnd,
|
||||||
wHeight,
|
padWidthBegin,
|
||||||
wWidth,
|
padWidthEnd,
|
||||||
padHeightBegin,
|
strideHeight,
|
||||||
padHeightEnd,
|
strideWidth,
|
||||||
padWidthBegin,
|
dilationHeight,
|
||||||
padWidthEnd,
|
dilationWidth,
|
||||||
strideHeight,
|
outWidth,
|
||||||
strideWidth,
|
patchSize,
|
||||||
dilationHeight,
|
numPatches,
|
||||||
dilationWidth,
|
numPatchesPerBatch,
|
||||||
outWidth,
|
effectiveMaxParallelPixels,
|
||||||
patchSize,
|
rewriter,
|
||||||
numPatches,
|
loc);
|
||||||
numPatchesPerBatch,
|
|
||||||
effectiveMaxParallelPixels,
|
|
||||||
rewriter,
|
|
||||||
loc);
|
|
||||||
|
|
||||||
Value gemmB = buildPackedWeight(wDenseAttr,
|
Value gemmB = buildPackedWeight(wDenseAttr,
|
||||||
wTrans,
|
wTrans,
|
||||||
@@ -538,20 +552,25 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
Value gemmC = buildPackedBias(
|
Value gemmC = buildPackedBias(
|
||||||
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
|
||||||
Value gemmRows = ONNXGemmOp::create(rewriter,
|
SmallVector<Value> gemmRows;
|
||||||
loc,
|
gemmRows.reserve(gemmInputRows.size());
|
||||||
gemmOutputRowsType,
|
for (Value gemmInputRow : gemmInputRows) {
|
||||||
gemmInputRows,
|
Value gemmRow = ONNXGemmOp::create(rewriter,
|
||||||
gemmB,
|
loc,
|
||||||
gemmC,
|
gemmOutputRowType,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
gemmInputRow,
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
gemmB,
|
||||||
rewriter.getBoolAttr(false),
|
gemmC,
|
||||||
rewriter.getBoolAttr(false))
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
.getY();
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
gemmRows.push_back(gemmRow);
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(convOp,
|
rewriter.replaceOp(convOp,
|
||||||
createCollectedConvOutput(ValueRange {gemmRows},
|
createCollectedConvOutput(gemmRows,
|
||||||
convOp.getType(),
|
convOp.getType(),
|
||||||
gemmOutType,
|
gemmOutType,
|
||||||
nhwcType,
|
nhwcType,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/Support/LogicalResult.h"
|
#include "mlir/Support/LogicalResult.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
@@ -66,66 +65,6 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
|
|
||||||
using OpConversionPattern::OpConversionPattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
|
|
||||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
|
||||||
ConversionPatternRewriter& rewriter) const override;
|
|
||||||
};
|
|
||||||
|
|
||||||
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
|
|
||||||
RankedTensorType matrixType,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
const int64_t numRows = matrixType.getDimSize(0);
|
|
||||||
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
|
|
||||||
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
|
|
||||||
|
|
||||||
auto buildRowSlices = [&](Value matrixArg) {
|
|
||||||
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
|
|
||||||
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
|
|
||||||
};
|
|
||||||
|
|
||||||
auto cloneBatchInputChainIntoSliceCompute =
|
|
||||||
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
|
|
||||||
auto sliceCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
|
|
||||||
Value transformedMatrix = input;
|
|
||||||
if (!chainOps.empty()) {
|
|
||||||
IRMapping mapper;
|
|
||||||
mapper.map(rootValue, input);
|
|
||||||
for (Operation* chainOp : chainOps)
|
|
||||||
rewriter.clone(*chainOp, mapper);
|
|
||||||
transformedMatrix = cast<Value>(mapper.lookup(matrix));
|
|
||||||
}
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
|
|
||||||
});
|
|
||||||
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
|
|
||||||
return rowSlices;
|
|
||||||
};
|
|
||||||
|
|
||||||
SmallVector<Operation*> chainOps;
|
|
||||||
Value rootValue = matrix;
|
|
||||||
while (Operation* definingOp = rootValue.getDefiningOp()) {
|
|
||||||
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
|
|
||||||
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
|
|
||||||
return cloneBatchInputChainIntoSliceCompute(
|
|
||||||
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (definingOp->getNumOperands() != 1)
|
|
||||||
break;
|
|
||||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
|
||||||
break;
|
|
||||||
|
|
||||||
chainOps.push_back(definingOp);
|
|
||||||
rootValue = definingOp->getOperand(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildRowSlices(matrix);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||||
@@ -217,7 +156,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
@@ -373,116 +313,15 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
|
||||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
|
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
||||||
});
|
});
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|
||||||
ONNXGemmOpAdaptor gemmOpAdaptor,
|
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Location loc = gemmOp.getLoc();
|
|
||||||
Value a = gemmOpAdaptor.getA();
|
|
||||||
Value b = gemmOpAdaptor.getB();
|
|
||||||
Value c = gemmOpAdaptor.getC();
|
|
||||||
|
|
||||||
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
|
|
||||||
|
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
|
||||||
|
|
||||||
auto aType = cast<RankedTensorType>(a.getType());
|
|
||||||
auto bType = cast<RankedTensorType>(b.getType());
|
|
||||||
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
|
|
||||||
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape());
|
|
||||||
|
|
||||||
const int64_t numOutRows = aType.getDimSize(0);
|
|
||||||
if (numOutRows <= 1)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
|
|
||||||
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|
|
||||||
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
|
||||||
if (failed(scaledB))
|
|
||||||
return failure();
|
|
||||||
b = *scaledB;
|
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
|
||||||
|
|
||||||
if (gemmOpAdaptor.getTransB()) {
|
|
||||||
auto bShape = bType.getShape();
|
|
||||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
|
||||||
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
bType = cast<RankedTensorType>(b.getType());
|
|
||||||
}
|
|
||||||
(void) bType;
|
|
||||||
|
|
||||||
Value sharedBias;
|
|
||||||
if (hasC) {
|
|
||||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
|
||||||
if (failed(scaledC))
|
|
||||||
return failure();
|
|
||||||
c = *scaledC;
|
|
||||||
auto cType = cast<RankedTensorType>(c.getType());
|
|
||||||
if (cType.getRank() == 1) {
|
|
||||||
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
|
||||||
c = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
expandedType,
|
|
||||||
c,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
cType = cast<RankedTensorType>(c.getType());
|
|
||||||
}
|
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
|
||||||
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
|
|
||||||
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
|
|
||||||
return failure();
|
|
||||||
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
|
|
||||||
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
|
|
||||||
sharedBias = c;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
|
||||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
|
||||||
|
|
||||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
|
||||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
|
||||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
|
||||||
|
|
||||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
|
||||||
loc,
|
|
||||||
TypeRange(resultTypes),
|
|
||||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
|
||||||
ValueRange(weights),
|
|
||||||
ValueRange(aSlices));
|
|
||||||
|
|
||||||
Block* body = rewriter.createBlock(
|
|
||||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
|
||||||
rewriter.setInsertionPointToEnd(body);
|
|
||||||
|
|
||||||
Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
|
||||||
Value laneResult = vmmResult;
|
|
||||||
if (sharedBias)
|
|
||||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
|
||||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
|
||||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
|
||||||
});
|
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
|
|
||||||
patterns.insert<GemmToManyGemv>(ctx);
|
patterns.insert<GemmToManyGemv>(ctx);
|
||||||
patterns.insert<GemvToSpatialCompute>(ctx);
|
patterns.insert<GemvToSpatialCompute>(ctx);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
@@ -15,108 +14,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value extractBatchMatrix(Value value,
|
|
||||||
int64_t batchIndex,
|
|
||||||
int64_t batchSize,
|
|
||||||
int64_t rows,
|
|
||||||
int64_t cols,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
|
||||||
if (type.getRank() == 2)
|
|
||||||
return value;
|
|
||||||
|
|
||||||
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
|
|
||||||
SmallVector<OpFoldResult> offsets = {
|
|
||||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
|
||||||
SmallVector<OpFoldResult> sizes = {
|
|
||||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
||||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
|
|
||||||
|
|
||||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
|
||||||
return tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
matrixType,
|
|
||||||
slice,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1},
|
|
||||||
{2}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isConstantLikeOperand(Value value) {
|
|
||||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
|
||||||
|
|
||||||
while (auto* definingOp = value.getDefiningOp()) {
|
|
||||||
if (!visited.insert(definingOp).second)
|
|
||||||
return false;
|
|
||||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
|
||||||
return true;
|
|
||||||
|
|
||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
|
||||||
value = extractSliceOp.getSource();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
|
||||||
value = expandShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
|
||||||
value = collapseShapeOp.getSrc();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
|
||||||
value = transposeOp.getData();
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
|
||||||
auto shape = type.getShape();
|
|
||||||
if (type.getRank() == 2) {
|
|
||||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
|
||||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
|
||||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
|
||||||
auto shape = type.getShape();
|
|
||||||
RankedTensorType transposedType;
|
|
||||||
SmallVector<int64_t> perm;
|
|
||||||
if (type.getRank() == 2) {
|
|
||||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
|
||||||
perm = {1, 0};
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
|
||||||
perm = {0, 2, 1};
|
|
||||||
}
|
|
||||||
|
|
||||||
auto transposeCompute =
|
|
||||||
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
|
|
||||||
Value transposed =
|
|
||||||
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
|
||||||
});
|
|
||||||
return transposeCompute.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
@@ -126,113 +24,80 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
||||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
|
||||||
return failure();
|
|
||||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
|
||||||
|| !haveStaticPositiveShape(outType.getShape()))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
const int64_t batch = rhsType.getDimSize(0);
|
||||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
const int64_t k = rhsType.getDimSize(1);
|
||||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
const int64_t n = rhsType.getDimSize(2);
|
||||||
|
const int64_t m = lhsType.getDimSize(0);
|
||||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
||||||
|
|| outType.getDimSize(2) != n)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
|
||||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
|
||||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
|
||||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
|
||||||
if (k != rhsK)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (outType.getRank() == 2) {
|
|
||||||
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
|
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
||||||
|
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
||||||
|
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
||||||
|
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
||||||
|
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
||||||
|
|
||||||
Value lhs = matmulOp.getA();
|
Value lhsTransposed =
|
||||||
Value rhs = matmulOp.getB();
|
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
||||||
int64_t lhsBatchForGemm = lhsBatch;
|
|
||||||
int64_t rhsBatchForGemm = rhsBatch;
|
|
||||||
int64_t gemmM = m;
|
|
||||||
int64_t gemmK = k;
|
|
||||||
int64_t gemmN = n;
|
|
||||||
if (useTransposedForm) {
|
|
||||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
|
||||||
lhsBatchForGemm = rhsBatch;
|
|
||||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
|
||||||
rhsBatchForGemm = lhsBatch;
|
|
||||||
gemmM = n;
|
|
||||||
gemmN = m;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
|
||||||
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
|
||||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
if (outType.getRank() == 2) {
|
SmallVector<Value> gemmRows;
|
||||||
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
gemmRows.reserve(batch * n);
|
||||||
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
|
||||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
gemmType,
|
|
||||||
lhsMatrix,
|
|
||||||
rhsMatrix,
|
|
||||||
none,
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
|
||||||
rewriter.getBoolAttr(false),
|
|
||||||
rewriter.getBoolAttr(false))
|
|
||||||
.getY();
|
|
||||||
if (useTransposedForm)
|
|
||||||
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
rewriter.replaceOp(matmulOp, gemmResult);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> batchResults;
|
|
||||||
batchResults.reserve(batch);
|
|
||||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||||
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
||||||
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
SmallVector<OpFoldResult> offsets = {
|
||||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
||||||
loc,
|
SmallVector<OpFoldResult> sizes = {
|
||||||
gemmType,
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
||||||
lhsMatrix,
|
SmallVector<OpFoldResult> strides = {
|
||||||
rhsMatrix,
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
none,
|
Value rhsSlice =
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
|
||||||
rewriter.getBoolAttr(false),
|
loc,
|
||||||
rewriter.getBoolAttr(false))
|
rhsRowType,
|
||||||
.getY();
|
rhsSlice,
|
||||||
if (useTransposedForm)
|
SmallVector<ReassociationIndices> {
|
||||||
gemmResult = ONNXTransposeOp::create(
|
{0},
|
||||||
rewriter,
|
{1, 2}
|
||||||
loc,
|
});
|
||||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
|
||||||
gemmResult,
|
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||||
rewriter.getI64ArrayAttr({1, 0}));
|
loc,
|
||||||
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
|
gemmRowType,
|
||||||
loc,
|
rhsRow,
|
||||||
batchedOutType,
|
lhsTransposed,
|
||||||
gemmResult,
|
none,
|
||||||
SmallVector<ReassociationIndices> {
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
{0, 1},
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
{2}
|
rewriter.getBoolAttr(false),
|
||||||
}));
|
rewriter.getBoolAttr(false));
|
||||||
|
gemmRows.push_back(gemmOp.getY());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults);
|
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
|
||||||
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
|
});
|
||||||
|
|
||||||
|
Value gemmOut = concatComputeOp.getResult(0);
|
||||||
|
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmExpandedType,
|
||||||
|
gemmOut,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||||
|
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -241,7 +106,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.insert<MatMulToGemm>(ctx);
|
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -100,7 +100,8 @@ static Value buildReduceMeanKeepdims(Value input,
|
|||||||
for (Value slice : slices)
|
for (Value slice : slices)
|
||||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||||
|
|
||||||
return createSpatConcat(rewriter, loc, axis, reducedSlices);
|
return reducedSlices.size() == 1 ? reducedSlices.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||||
|
|||||||
@@ -33,7 +33,9 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
|||||||
|
|
||||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||||
return createSpatConcat(rewriter, loc, axis, values);
|
if (values.size() == 1)
|
||||||
|
return values.front();
|
||||||
|
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||||
|
|||||||
@@ -47,7 +47,8 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
|
|||||||
for (Value slice : slices)
|
for (Value slice : slices)
|
||||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||||
|
|
||||||
return createSpatConcat(rewriter, loc, axis, rebuiltSlices);
|
return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -18,7 +17,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
|||||||
auto inputs = adaptor.getInputs();
|
auto inputs = adaptor.getInputs();
|
||||||
int64_t axis = adaptor.getAxis();
|
int64_t axis = adaptor.getAxis();
|
||||||
|
|
||||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
|
|||||||
}
|
}
|
||||||
if (slices.empty())
|
if (slices.empty())
|
||||||
return {};
|
return {};
|
||||||
return createSpatConcat(rewriter, loc, axis, slices);
|
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -130,7 +130,9 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||||
}
|
}
|
||||||
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
result = rows.size() == 1
|
||||||
|
? rows.front()
|
||||||
|
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
|
|||||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||||
}
|
}
|
||||||
|
|
||||||
return createSpatConcat(rewriter, loc, axis, slices);
|
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||||
|
|||||||
@@ -23,10 +23,7 @@ static Value extractSliceAt(
|
|||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||||
sizes[axis] = rewriter.getIndexAttr(size);
|
sizes[axis] = rewriter.getIndexAttr(size);
|
||||||
SmallVector<int64_t> resultShape(inputType.getShape());
|
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||||
resultShape[axis] = size;
|
|
||||||
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||||
@@ -52,7 +49,12 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
int64_t sliceSize = resultType.getShape()[axis];
|
int64_t sliceSize = resultType.getShape()[axis];
|
||||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
auto computeOp =
|
||||||
|
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
|
||||||
|
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
|
||||||
|
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
|
||||||
|
});
|
||||||
|
outputs.push_back(computeOp.getResult(0));
|
||||||
offset += sliceSize;
|
offset += sliceSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ add_public_tablegen_target(SpatialToPimIncGen)
|
|||||||
add_pim_library(OMSpatialToPim
|
add_pim_library(OMSpatialToPim
|
||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
Common.cpp
|
Common.cpp
|
||||||
Patterns.cpp
|
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -7,12 +7,23 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "Common.hpp"
|
#include "Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
|
||||||
|
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
|
||||||
|
assert(attr && "required precomputed channel attr is missing");
|
||||||
|
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||||
/*
|
/*
|
||||||
EXAMPLE RUN:
|
EXAMPLE RUN:
|
||||||
@@ -63,6 +74,37 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
|||||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||||
|
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
|
||||||
|
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||||
|
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
|
||||||
|
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value createPimReceiveFromSpatialChannel(
|
||||||
|
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||||
|
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||||
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||||
|
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||||
|
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||||
|
.getOutput();
|
||||||
|
}
|
||||||
|
|
||||||
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
Operation* getEarliestUserWithinBlock(mlir::Value value) {
|
||||||
auto users = value.getUsers();
|
auto users = value.getUsers();
|
||||||
|
|
||||||
@@ -85,16 +127,6 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
|||||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
|
|
||||||
for (Operation* user : value.getUsers()) {
|
|
||||||
if (user->getBlock() != operation->getBlock())
|
|
||||||
return true;
|
|
||||||
if (operation->isBeforeInBlock(user))
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||||
mlir::Value result = operation->getResult(0);
|
mlir::Value result = operation->getResult(0);
|
||||||
@@ -102,9 +134,8 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
|
|||||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||||
|
|
||||||
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||||
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
|
auto validOperands =
|
||||||
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
|
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
||||||
});
|
|
||||||
auto bestOperand = validOperands.begin();
|
auto bestOperand = validOperands.begin();
|
||||||
|
|
||||||
if (bestOperand != validOperands.end())
|
if (bestOperand != validOperands.end())
|
||||||
|
|||||||
@@ -2,10 +2,16 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
|
||||||
|
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||||
* its static tensor input.
|
* its static tensor input.
|
||||||
@@ -24,6 +30,17 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
|||||||
|
|
||||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||||
|
|
||||||
|
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||||
|
|
||||||
|
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
|
||||||
|
|
||||||
|
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
|
||||||
|
|
||||||
|
mlir::Value createPimReceiveFromSpatialChannel(
|
||||||
|
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
size_t rangeLength(const mlir::iterator_range<T> range) {
|
size_t rangeLength(const mlir::iterator_range<T> range) {
|
||||||
return std::distance(range.begin(), range.end());
|
return std::distance(range.begin(), range.end());
|
||||||
|
|||||||
@@ -1,385 +0,0 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
|
||||||
#include "llvm/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
|
|
||||||
Location loc = extractSliceOp.getLoc();
|
|
||||||
|
|
||||||
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
for (auto& uses : extractSliceOp->getUses()) {
|
|
||||||
if (isa<spatial::SpatCompute>(uses.getOwner())) {
|
|
||||||
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
|
|
||||||
if (spatCompute.getInputs().empty())
|
|
||||||
return failure();
|
|
||||||
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::DenseMap<Operation*, Value> mapSpatToExtract;
|
|
||||||
|
|
||||||
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
|
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
|
|
||||||
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
||||||
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]);
|
|
||||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
|
||||||
auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
||||||
|
|
||||||
if (BBArgValue.use_empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
||||||
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
|
||||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
{
|
|
||||||
if (auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatToExtract.contains(spatCompute.getOperation())) {
|
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
||||||
mapSpatToExtract.insert({spatCompute.getOperation(), newExtractSlice->getResult(0)});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
uses.set(mapSpatToExtract[spatCompute.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = uses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatToExtract.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
|
|
||||||
mapSpatToExtract.insert({spatComputeBatch.getOperation(), newExtractSlice->getResult(0)});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
uses.set(mapSpatToExtract[spatComputeBatch.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.eraseOp(extractSliceOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
|
||||||
static int i = 0;
|
|
||||||
Location loc = constantOp.getLoc();
|
|
||||||
|
|
||||||
if (hasWeightAlways(constantOp))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
|
||||||
if (isa<spatial::SpatCompute>(op))
|
|
||||||
return false;
|
|
||||||
if (isa<func::FuncOp>(op->getParentOp()))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
|
||||||
|
|
||||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
|
||||||
|
|
||||||
if (constRankedTensorType) {
|
|
||||||
mlir::MemRefType memRefType =
|
|
||||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
|
||||||
std::string argName = "const_" + std::to_string(i++);
|
|
||||||
memref::GlobalOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
rewriter.getStringAttr(argName),
|
|
||||||
rewriter.getStringAttr("private"),
|
|
||||||
TypeAttr::get(memRefType),
|
|
||||||
constantOp.getValueAttr(),
|
|
||||||
rewriter.getUnitAttr(),
|
|
||||||
{});
|
|
||||||
|
|
||||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
||||||
|
|
||||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
||||||
auto constUsers = constUses.getOwner();
|
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
||||||
|
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
||||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
||||||
|
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
||||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
{
|
|
||||||
|
|
||||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
|
||||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
|
||||||
|
|
||||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
|
||||||
auto constUsers = constUses.getOwner();
|
|
||||||
|
|
||||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
|
||||||
|
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
|
||||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
|
||||||
|
|
||||||
auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
|
||||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
|
||||||
if (!mapSpatComputeToConst.contains(parent)) {
|
|
||||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
|
||||||
}
|
|
||||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
|
||||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
|
||||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
|
||||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
|
||||||
auto newConst = rewriter.clone(*constantOp);
|
|
||||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
|
||||||
}
|
|
||||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto parent = constantOp->getParentOp();
|
|
||||||
rewriter.eraseOp(constantOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
|
|
||||||
|
|
||||||
if (funcOp.getArguments().empty())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (llvm::all_of(funcOp.getArguments(),
|
|
||||||
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Location loc = funcOp.getLoc();
|
|
||||||
|
|
||||||
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
|
|
||||||
if (arg.getUses().empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(funcOp.getOperation());
|
|
||||||
|
|
||||||
assert(isa<mlir::RankedTensorType>(arg.getType()));
|
|
||||||
|
|
||||||
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
|
|
||||||
mlir::MemRefType memRefType =
|
|
||||||
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
|
|
||||||
|
|
||||||
std::string argName = "arg_" + std::to_string(index);
|
|
||||||
|
|
||||||
memref::GlobalOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
rewriter.getStringAttr(argName),
|
|
||||||
rewriter.getStringAttr("private"),
|
|
||||||
TypeAttr::get(memRefType),
|
|
||||||
{},
|
|
||||||
{},
|
|
||||||
{});
|
|
||||||
|
|
||||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
|
||||||
auto argUser = argUses.getOwner();
|
|
||||||
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
|
|
||||||
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(toTensor);
|
|
||||||
spatCompute.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
|
||||||
}
|
|
||||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(argUser)) {
|
|
||||||
auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex();
|
|
||||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
|
||||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
|
|
||||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
|
||||||
BBArgValue.replaceAllUsesWith(toTensor);
|
|
||||||
spatComputeBatch.getInputsMutable().erase(BBArgIndex);
|
|
||||||
spatComputeBatch.getBody().front().eraseArgument(BBArgIndex);
|
|
||||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
rewriter.setInsertionPoint(argUser);
|
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
|
||||||
auto toTensor = bufferization::ToTensorOp::create(
|
|
||||||
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
|
||||||
rewriter.startOpModification(argUser);
|
|
||||||
argUses.set(toTensor);
|
|
||||||
rewriter.finalizeOpModification(argUser);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
|
||||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
|
||||||
patterns.getContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -9,6 +9,17 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
def HasSpatialChannelSourceCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed source core id">;
|
||||||
|
|
||||||
|
def HasSpatialChannelTargetCoreIdAttr: Constraint<
|
||||||
|
CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">,
|
||||||
|
"spatial channel has precomputed target core id">;
|
||||||
|
|
||||||
|
def createPimReceiveFromSpatialChannelValue: NativeCodeCall<
|
||||||
|
"onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">;
|
||||||
|
|
||||||
def onnxToPimTranspose : Pat<
|
def onnxToPimTranspose : Pat<
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
||||||
(PimTransposeOp $data, $perms,
|
(PimTransposeOp $data, $perms,
|
||||||
@@ -69,4 +80,18 @@ def spatToPimVSoftmax : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
def spatChannelSendToPimSend : Pat<
|
||||||
|
(SpatChannelSendOp $channel, $input),
|
||||||
|
(PimSendOp $input,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input),
|
||||||
|
(NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)),
|
||||||
|
[(HasSpatialChannelTargetCoreIdAttr $channel)]
|
||||||
|
>;
|
||||||
|
|
||||||
|
def spatChannelReceiveToPimReceive : Pat<
|
||||||
|
(SpatChannelReceiveOp:$srcOpRes $channel),
|
||||||
|
(createPimReceiveFromSpatialChannelValue $srcOpRes, $channel),
|
||||||
|
[(HasSpatialChannelSourceCoreIdAttr $channel)]
|
||||||
|
>;
|
||||||
|
|
||||||
#endif // SPATIAL_TO_PIM
|
#endif // SPATIAL_TO_PIM
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -24,7 +24,7 @@ def PimTensor :
|
|||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
def PimCoreOp : PimOp<"core", [SingleBlock]> {
|
||||||
let summary = "Execute a block on a PIM core";
|
let summary = "Execute a block on a PIM core";
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
let regions = (region SizedRegion<1>:$body);
|
||||||
@@ -39,22 +39,6 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> {
|
|
||||||
let summary = "Execute equivalent batched core bodies";
|
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr:$laneCount,
|
|
||||||
Variadic<PimTensor>:$weights,
|
|
||||||
Variadic<PimTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)`
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
def PimHaltOp : PimOp<"halt", [Terminator]> {
|
||||||
let summary = "Halt execution of the core";
|
let summary = "Halt execution of the core";
|
||||||
|
|
||||||
@@ -81,20 +65,6 @@ def PimSendOp : PimOp<"send", []> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimSendBatchOp : PimOp<"send_batch", []> {
|
|
||||||
let summary = "Send a per-lane tensor to target cores from a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$input,
|
|
||||||
I32Attr:$size,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Receive a tensor from another core";
|
let summary = "Receive a tensor from another core";
|
||||||
|
|
||||||
@@ -119,30 +89,6 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Receive per-lane tensors from source cores into a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$outputBuffer,
|
|
||||||
I32Attr:$size,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getOutputBufferMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Copy a memory region from host memory into device memory";
|
let summary = "Copy a memory region from host memory into device memory";
|
||||||
|
|
||||||
@@ -169,32 +115,6 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> {
|
|
||||||
let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
PimTensor:$deviceTarget,
|
|
||||||
PimTensor:$hostSource,
|
|
||||||
I32Attr:$deviceTargetOffset,
|
|
||||||
I32Attr:$hostSourceOffset,
|
|
||||||
I32Attr:$size
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
PimTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let extraClassDeclaration = [{
|
|
||||||
mlir::MutableOperandRange getDpsInitsMutable() {
|
|
||||||
return getDeviceTargetMutable();
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
|
||||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||||
let summary = "Copy a memory region from device memory into host memory";
|
let summary = "Copy a memory region from device memory into host memory";
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
|
||||||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "OpBufferizationInterfaces.hpp"
|
#include "OpBufferizationInterfaces.hpp"
|
||||||
@@ -66,32 +65,6 @@ struct MemCopyHostToDevOpInterface
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct MemCopyHostToDevBatchOpInterface
|
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevBatchOpInterface, PimMemCopyHostToDevBatchOp> {
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevBatchOp>(op);
|
|
||||||
auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state);
|
|
||||||
if (failed(deviceTargetOpt))
|
|
||||||
return failure();
|
|
||||||
auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state);
|
|
||||||
if (failed(hostSourceOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevBatchOp>(rewriter,
|
|
||||||
memCopyHostToDevOp,
|
|
||||||
deviceTargetOpt->getType(),
|
|
||||||
*deviceTargetOpt,
|
|
||||||
*hostSourceOpt,
|
|
||||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
|
||||||
memCopyHostToDevOp.getSizeAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct MemCopyDevToHostOpInterface
|
struct MemCopyDevToHostOpInterface
|
||||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||||
LogicalResult bufferize(Operation* op,
|
LogicalResult bufferize(Operation* op,
|
||||||
@@ -149,127 +122,6 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ReceiveBatchOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveBatchOpInterface, PimReceiveBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto receiveOp = cast<PimReceiveBatchOp>(op);
|
|
||||||
auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state);
|
|
||||||
if (failed(outputBufferOpt))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimReceiveBatchOp>(rewriter,
|
|
||||||
op,
|
|
||||||
outputBufferOpt->getType(),
|
|
||||||
*outputBufferOpt,
|
|
||||||
receiveOp.getSizeAttr(),
|
|
||||||
receiveOp.getSourceCoreIdsAttr());
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOpInterface, PimCoreBatchOp> {
|
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
|
||||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
|
||||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
|
||||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
|
||||||
return {};
|
|
||||||
|
|
||||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
|
||||||
return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}};
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<BufferLikeType>
|
|
||||||
getBufferType(Operation* op,
|
|
||||||
Value value,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
const BufferizationState& state,
|
|
||||||
SmallVector<Value>& invocationStack) const {
|
|
||||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
|
||||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
|
||||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
|
||||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
|
||||||
return memRefType;
|
|
||||||
|
|
||||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult bufferize(Operation* op,
|
|
||||||
RewriterBase& rewriter,
|
|
||||||
const BufferizationOptions& options,
|
|
||||||
BufferizationState& state) const {
|
|
||||||
auto coreBatchOp = cast<PimCoreBatchOp>(op);
|
|
||||||
|
|
||||||
SmallVector<Value> weights;
|
|
||||||
SmallVector<Value> inputs;
|
|
||||||
weights.reserve(coreBatchOp.getWeights().size());
|
|
||||||
inputs.reserve(coreBatchOp.getInputs().size());
|
|
||||||
|
|
||||||
for (Value weight : coreBatchOp.getWeights()) {
|
|
||||||
if (isa<TensorType>(weight.getType())) {
|
|
||||||
auto weightOpt = getBuffer(rewriter, weight, options, state);
|
|
||||||
if (failed(weightOpt))
|
|
||||||
return failure();
|
|
||||||
weights.push_back(*weightOpt);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
weights.push_back(weight);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (Value input : coreBatchOp.getInputs()) {
|
|
||||||
if (isa<TensorType>(input.getType())) {
|
|
||||||
auto inputOpt = getBuffer(rewriter, input, options, state);
|
|
||||||
if (failed(inputOpt))
|
|
||||||
return failure();
|
|
||||||
inputs.push_back(*inputOpt);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
inputs.push_back(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(coreBatchOp);
|
|
||||||
auto newOp = PimCoreBatchOp::create(
|
|
||||||
rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs));
|
|
||||||
newOp.getProperties().setOperandSegmentSizes({static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
|
||||||
if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName))
|
|
||||||
newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr);
|
|
||||||
|
|
||||||
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin());
|
|
||||||
for (Block& block : newOp.getBody())
|
|
||||||
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state)))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
rewriter.eraseOp(coreBatchOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel<TransposeOpInterface, PimTransposeOp> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
@@ -326,10 +178,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -353,10 +203,8 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
|
|||||||
if (failed(outputBufferOpt))
|
if (failed(outputBufferOpt))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
|
||||||
|
|
||||||
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
replaceOpWithNewBufferizedOp<PimMVMOp>(
|
||||||
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -435,11 +283,8 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpI
|
|||||||
|
|
||||||
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||||
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
registry.addExtension(+[](MLIRContext* ctx, PimDialect* dialect) {
|
||||||
PimCoreBatchOp::attachInterface<CoreBatchOpInterface>(*ctx);
|
|
||||||
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
PimReceiveOp::attachInterface<ReceiveOpInterface>(*ctx);
|
||||||
PimReceiveBatchOp::attachInterface<ReceiveBatchOpInterface>(*ctx);
|
|
||||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
|
||||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||||
|
|||||||
@@ -3,17 +3,12 @@
|
|||||||
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/IR/Threading.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
@@ -45,44 +40,14 @@ private:
|
|||||||
|
|
||||||
void PimBufferizationPass::runOnOperation() {
|
void PimBufferizationPass::runOnOperation() {
|
||||||
auto moduleOp = getOperation();
|
auto moduleOp = getOperation();
|
||||||
// Refactor this into a function
|
|
||||||
{
|
|
||||||
auto funcOp = getPimEntryFunc(moduleOp);
|
|
||||||
|
|
||||||
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
|
// One-Shot-Bufferization
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
bufferization::OneShotBufferizationOptions options;
|
||||||
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
|
options.allowUnknownOps = true;
|
||||||
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
|
bufferization::BufferizationState state;
|
||||||
// Again, allocate state LOCALLY per thread/function
|
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
||||||
bufferization::OneShotBufferizationOptions options;
|
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
||||||
options.allowUnknownOps = true;
|
signalPassFailure();
|
||||||
bufferization::BufferizationState state;
|
|
||||||
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
|
|
||||||
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
});
|
|
||||||
|
|
||||||
if (failed(result)) {
|
|
||||||
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
|
|
||||||
signalPassFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
|
|
||||||
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
|
|
||||||
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
|
|
||||||
});
|
|
||||||
|
|
||||||
// One-Shot-Bufferization
|
|
||||||
bufferization::OneShotBufferizationOptions options;
|
|
||||||
options.allowUnknownOps = true;
|
|
||||||
bufferization::BufferizationState state;
|
|
||||||
|
|
||||||
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
|
|
||||||
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
|
|
||||||
signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
@@ -92,18 +57,7 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
populateWithGenerated(patterns);
|
populateWithGenerated(patterns);
|
||||||
|
|
||||||
// Only convert memref.copy → pim.memcp inside pim.core / pim.core_batch bodies.
|
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||||
// Host-level copies (e.g. from split/slice ops) must remain as memref.copy for CPU lowering.
|
|
||||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
|
||||||
bool hasFailed = false;
|
|
||||||
moduleOp.walk<WalkOrder::PreOrder>([&](Operation* op) {
|
|
||||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
|
||||||
return WalkResult::advance();
|
|
||||||
if (failed(applyPartialConversion(op, target, frozenPatterns)))
|
|
||||||
hasFailed = true;
|
|
||||||
return WalkResult::skip();
|
|
||||||
});
|
|
||||||
if (hasFailed) {
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -139,9 +93,11 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
auto markWeights = [&](Operation* op) {
|
funcOp.walk([&](PimCoreOp coreOp) {
|
||||||
walkPimMvmVmmWeightUses(op, [&](OpOperand& weightUse) {
|
auto annotateWeight = [&](unsigned weightIndex) {
|
||||||
Value weight = weightUse.get();
|
if (weightIndex >= coreOp.getWeights().size())
|
||||||
|
return;
|
||||||
|
Value weight = coreOp.getWeights()[weightIndex];
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp)
|
if (!getGlobalOp)
|
||||||
return;
|
return;
|
||||||
@@ -149,11 +105,11 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
markWeightAlways(getGlobalOp);
|
markWeightAlways(getGlobalOp);
|
||||||
markWeightAlways(globalMemrefOp);
|
markWeightAlways(globalMemrefOp);
|
||||||
});
|
};
|
||||||
};
|
|
||||||
|
|
||||||
funcOp.walk([&](PimCoreOp coreOp) { markWeights(coreOp); });
|
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
||||||
funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); });
|
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
std::unique_ptr<Pass> createPimBufferizationPass() { return std::make_unique<PimBufferizationPass>(); }
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ add_onnx_mlir_dialect(Spatial spat)
|
|||||||
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
add_onnx_mlir_dialect_doc(spat Spatial.td)
|
||||||
|
|
||||||
add_pim_library(SpatialOps
|
add_pim_library(SpatialOps
|
||||||
Channels.cpp
|
|
||||||
SpatialOps.cpp
|
SpatialOps.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
|
|||||||
@@ -1,120 +0,0 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/Diagnostics.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace onnx_mlir::spatial {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
|
||||||
|
|
||||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
|
||||||
|
|
||||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
|
||||||
if (!endpoints.send || !endpoints.receive)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
|
||||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
|
||||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) {
|
|
||||||
endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
Channels::Channels(func::FuncOp funcOp) {
|
|
||||||
if (!funcOp)
|
|
||||||
return;
|
|
||||||
|
|
||||||
funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); });
|
|
||||||
funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); });
|
|
||||||
}
|
|
||||||
|
|
||||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
|
||||||
|
|
||||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
|
||||||
ChannelId channelId = getChannelId(sendOp);
|
|
||||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
|
||||||
endpoints[channelId].send = sendOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
|
||||||
ChannelId channelId = getChannelId(receiveOp);
|
|
||||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
|
||||||
endpoints[channelId].receive = receiveOp;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
|
||||||
ChannelId channelId = getChannelId(sendOp);
|
|
||||||
auto it = endpoints.find(channelId);
|
|
||||||
if (it == endpoints.end())
|
|
||||||
return;
|
|
||||||
it->second.send = {};
|
|
||||||
if (!it->second.receive)
|
|
||||||
endpoints.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
|
||||||
ChannelId channelId = getChannelId(receiveOp);
|
|
||||||
auto it = endpoints.find(channelId);
|
|
||||||
if (it == endpoints.end())
|
|
||||||
return;
|
|
||||||
it->second.receive = {};
|
|
||||||
if (!it->second.send)
|
|
||||||
endpoints.erase(it);
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
|
||||||
auto it = endpoints.find(id);
|
|
||||||
if (it == endpoints.end())
|
|
||||||
return failure();
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
|
||||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
|
||||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
|
||||||
return failure();
|
|
||||||
return endpointsOr->receive;
|
|
||||||
}
|
|
||||||
|
|
||||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
|
||||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
|
||||||
if (failed(endpointsOr) || !endpointsOr->send)
|
|
||||||
return failure();
|
|
||||||
return endpointsOr->send;
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult Channels::verify() const {
|
|
||||||
for (const auto& [channelId, pair] : endpoints) {
|
|
||||||
if (!pair.send || !pair.receive) {
|
|
||||||
if (pair.send) {
|
|
||||||
auto sendOp = pair.send;
|
|
||||||
sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive";
|
|
||||||
}
|
|
||||||
else if (pair.receive) {
|
|
||||||
auto receiveOp = pair.receive;
|
|
||||||
receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send";
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (failed(verifyEndpointPair(pair)))
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir::spatial
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/Support/LogicalResult.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir::spatial {
|
|
||||||
|
|
||||||
struct ChannelEndpoints {
|
|
||||||
SpatChannelSendOp send;
|
|
||||||
SpatChannelReceiveOp receive;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Channels {
|
|
||||||
public:
|
|
||||||
using ChannelId = int64_t;
|
|
||||||
|
|
||||||
explicit Channels(mlir::func::FuncOp funcOp);
|
|
||||||
|
|
||||||
ChannelId allocate();
|
|
||||||
|
|
||||||
void insertSend(SpatChannelSendOp sendOp);
|
|
||||||
void insertReceive(SpatChannelReceiveOp receiveOp);
|
|
||||||
void eraseSend(SpatChannelSendOp sendOp);
|
|
||||||
void eraseReceive(SpatChannelReceiveOp receiveOp);
|
|
||||||
|
|
||||||
llvm::FailureOr<ChannelEndpoints> lookup(ChannelId id) const;
|
|
||||||
llvm::FailureOr<SpatChannelReceiveOp> getReceiveFor(SpatChannelSendOp sendOp) const;
|
|
||||||
llvm::FailureOr<SpatChannelSendOp> getSendFor(SpatChannelReceiveOp receiveOp) const;
|
|
||||||
|
|
||||||
mlir::LogicalResult verify() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
ChannelId nextChannelId = 0;
|
|
||||||
llvm::DenseMap<ChannelId, ChannelEndpoints> endpoints;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace onnx_mlir::spatial
|
|
||||||
@@ -9,6 +9,7 @@ def SpatialDialect : Dialect {
|
|||||||
let name = "spat";
|
let name = "spat";
|
||||||
let summary = "Dialect designed for deep learning computation in a spatial architecture";
|
let summary = "Dialect designed for deep learning computation in a spatial architecture";
|
||||||
let cppNamespace = "::onnx_mlir::spatial";
|
let cppNamespace = "::onnx_mlir::spatial";
|
||||||
|
let useDefaultTypePrinterParser = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
class SpatOp<string mnemonic, list<Trait> traits = []> :
|
class SpatOp<string mnemonic, list<Trait> traits = []> :
|
||||||
@@ -18,6 +19,15 @@ class SpatOp<string mnemonic, list<Trait> traits = []> :
|
|||||||
def SpatTensor :
|
def SpatTensor :
|
||||||
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
|
||||||
|
|
||||||
|
class SpatType<string name, string typeMnemonic, list<Trait> traits = []>
|
||||||
|
: TypeDef<SpatialDialect, name, traits> {
|
||||||
|
let mnemonic = typeMnemonic;
|
||||||
|
}
|
||||||
|
|
||||||
|
def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
||||||
|
let summary = "Virtual channel type";
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Execution
|
// Execution
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -38,27 +48,10 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
|||||||
|
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
let assemblyFormat = [{
|
||||||
[SingleBlock, AttrSizedOperandSegments]> {
|
`[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body
|
||||||
let summary = "Compressed batch of independent equivalent compute lanes";
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I32Attr:$laneCount,
|
|
||||||
Variadic<SpatTensor>:$weights,
|
|
||||||
Variadic<SpatTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
Variadic<SpatTensor>:$outputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$body);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||||
@@ -68,66 +61,51 @@ def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
|||||||
Variadic<SpatTensor>:$outputs
|
Variadic<SpatTensor>:$outputs
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
let assemblyFormat = [{
|
||||||
}
|
$outputs attr-dict `:` type($outputs)
|
||||||
|
}];
|
||||||
def SpatExtractRowsOp : SpatOp<"extract_rows", []> {
|
|
||||||
let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
SpatTensor:$input
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
Variadic<SpatTensor>:$outputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatConcatOp : SpatOp<"concat", []> {
|
|
||||||
let summary = "Concatenate tensors with compact Spatial operand syntax";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
I64Attr:$axis,
|
|
||||||
Variadic<SpatTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
SpatTensor:$output
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Communication
|
// Communication
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def SpatChannelNewOp : SpatOp<"channel_new", []> {
|
||||||
|
let summary = "Create a new virtual channel";
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatChannelType:$channel
|
||||||
|
);
|
||||||
|
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<(ins ), [{
|
||||||
|
$_state.addTypes(SpatChannelType());
|
||||||
|
}]>
|
||||||
|
];
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
attr-dict
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||||
let summary = "Send a tensor through a logical channel";
|
let summary = "Send a tensor through a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I64Attr:$channelId,
|
SpatChannelType:$channel,
|
||||||
I32Attr:$sourceCoreId,
|
|
||||||
I32Attr:$targetCoreId,
|
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$input attr-dict `:` type($input)
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||||
let summary = "Receive a tensor from a logical channel";
|
let summary = "Receive a tensor from a channel";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
I64Attr:$channelId,
|
SpatChannelType:$channel
|
||||||
I32Attr:$sourceCoreId,
|
|
||||||
I32Attr:$targetCoreId
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
@@ -135,70 +113,37 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
attr-dict `:` type($output)
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> {
|
def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> {
|
||||||
let summary = "Send multiple tensors through logical channels";
|
let summary = "Broadcast a tensor through a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
SpatChannelType:$channel,
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds,
|
|
||||||
Variadic<SpatTensor>:$inputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> {
|
|
||||||
let summary = "Receive multiple tensors from logical channels";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
DenseI64ArrayAttr:$channelIds,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs
|
|
||||||
Variadic<SpatTensor>:$outputs
|
|
||||||
);
|
|
||||||
|
|
||||||
let hasVerifier = 1;
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
|
||||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
|
||||||
|
|
||||||
let arguments = (ins
|
|
||||||
DenseI64ArrayAttr:$channelIds,
|
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds,
|
|
||||||
SpatTensor:$input
|
SpatTensor:$input
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let assemblyFormat = [{
|
||||||
let hasCustomAssemblyFormat = 1;
|
$input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)`
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> {
|
||||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
let summary = "Receive a tensor from a shared channel buffer";
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
DenseI64ArrayAttr:$channelIds,
|
SpatChannelType:$channel
|
||||||
DenseI32ArrayAttr:$sourceCoreIds,
|
|
||||||
DenseI32ArrayAttr:$targetCoreIds
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs
|
let results = (outs
|
||||||
SpatTensor:$output
|
SpatTensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
let hasVerifier = 1;
|
let assemblyFormat = [{
|
||||||
let hasCustomAssemblyFormat = 1;
|
$channel attr-dict `:` `(` type($channel) `->` type($output) `)`
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -3,17 +3,15 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <map>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <queue>
|
#include <set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -28,11 +26,9 @@ namespace spatial {
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
|
||||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
|
||||||
|
|
||||||
struct VirtualNode {
|
struct VirtualNode {
|
||||||
SmallVector<size_t, 4> originalComputeIndices;
|
llvm::SmallVector<size_t, 4> originalComputeIndices;
|
||||||
Weight weight = 0;
|
Weight weight = 0;
|
||||||
CrossbarUsage crossbarUsage = 0;
|
CrossbarUsage crossbarUsage = 0;
|
||||||
};
|
};
|
||||||
@@ -51,52 +47,11 @@ struct TimingInfo {
|
|||||||
|
|
||||||
struct WindowScheduleResult {
|
struct WindowScheduleResult {
|
||||||
std::vector<std::vector<size_t>> mergeGroups;
|
std::vector<std::vector<size_t>> mergeGroups;
|
||||||
CPU cpuCount = 0;
|
bool usedAllAvailableCpus = false;
|
||||||
size_t mergedNodeCount = 0;
|
|
||||||
size_t maxMergeGroupSize = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
constexpr CPU kDefaultMaxCpuCount = 1000;
|
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
size_t getSchedulingCpuBudget() {
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
return static_cast<size_t>(coresCount.getValue());
|
|
||||||
return static_cast<size_t>(kDefaultMaxCpuCount);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
|
|
||||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
|
||||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
|
||||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
|
||||||
|
|
||||||
size_t chunkIndex = 0;
|
|
||||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
|
||||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
|
||||||
else
|
|
||||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
|
||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
|
||||||
for (auto [start, end, weight] : edges) {
|
for (auto [start, end, weight] : edges) {
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
@@ -104,9 +59,11 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|||||||
continue;
|
continue;
|
||||||
auto key = std::make_pair(startIndex, endIndex);
|
auto key = std::make_pair(startIndex, endIndex);
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
auto it = edgeWeights.find(key);
|
||||||
if (!inserted.second)
|
if (it == edgeWeights.end())
|
||||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
edgeWeights.insert({key, edgeWeight});
|
||||||
|
else
|
||||||
|
it->second = std::max(it->second, edgeWeight);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregatedEdges;
|
std::vector<IndexedEdge> aggregatedEdges;
|
||||||
@@ -114,104 +71,18 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|||||||
for (auto [key, weight] : edgeWeights)
|
for (auto [key, weight] : edgeWeights)
|
||||||
aggregatedEdges.push_back(
|
aggregatedEdges.push_back(
|
||||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
|
||||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
|
||||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
|
||||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
|
||||||
});
|
|
||||||
return aggregatedEdges;
|
return aggregatedEdges;
|
||||||
}
|
}
|
||||||
|
|
||||||
Weight getComputeBodyWeight(Region& body) {
|
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
constexpr Weight kOperationWeight = 100;
|
llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
Weight numOperations = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for ([[maybe_unused]] auto& op : block)
|
|
||||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
|
||||||
return checkedMultiply(numOperations, kOperationWeight);
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
|
||||||
CrossbarUsage crossbarUsage = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for (auto& op : block)
|
|
||||||
if (isa<SpatWeightedVMMOp>(op))
|
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
|
||||||
return crossbarUsage;
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeWeight(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeCrossbarUsage(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
SmallVector<Value, 4> inputs;
|
|
||||||
inputs.reserve(instance.laneCount);
|
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
|
||||||
inputs.push_back(batch.getInputs()[lane]);
|
|
||||||
return inputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
|
||||||
Operation* op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
value = extract.getSource();
|
|
||||||
op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
|
||||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
|
||||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
|
||||||
SmallVector<ComputeInstance> instances;
|
|
||||||
for (Region& region : entryOp->getRegions()) {
|
|
||||||
for (Block& block : region) {
|
|
||||||
for (Operation& op : block) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
|
||||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
|
||||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return instances;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
|
|
||||||
VirtualGraph graph;
|
VirtualGraph graph;
|
||||||
graph.nodes.reserve(computeInstances.size());
|
graph.nodes.reserve(spatComputes.size());
|
||||||
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
|
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
VirtualNode node;
|
VirtualNode node;
|
||||||
node.originalComputeIndices.push_back(index);
|
node.originalComputeIndices.push_back(index);
|
||||||
node.weight = getComputeInstanceWeight(computeInstance);
|
node.weight = getSpatComputeWeight(spatCompute);
|
||||||
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
|
node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
|
||||||
graph.nodes.push_back(std::move(node));
|
graph.nodes.push_back(std::move(node));
|
||||||
}
|
}
|
||||||
graph.edges = aggregateEdges(edges);
|
graph.edges = aggregateEdges(edges);
|
||||||
@@ -239,34 +110,22 @@ TimingInfo computeTiming(const VirtualGraph& graph) {
|
|||||||
incomingEdgeCount[endIndex]++;
|
incomingEdgeCount[endIndex]++;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
std::vector<size_t> readyNodes;
|
||||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
readyNodes.reserve(nodeCount);
|
||||||
if (!node.originalComputeIndices.empty())
|
|
||||||
return node.originalComputeIndices.front();
|
|
||||||
return nodeIndex;
|
|
||||||
};
|
|
||||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
|
||||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
|
||||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
|
||||||
if (lhsKey != rhsKey)
|
|
||||||
return lhsKey > rhsKey;
|
|
||||||
return lhs > rhs;
|
|
||||||
};
|
|
||||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
|
||||||
for (size_t i = 0; i < nodeCount; ++i)
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
if (incomingEdgeCount[i] == 0)
|
if (incomingEdgeCount[i] == 0)
|
||||||
readyNodes.push(i);
|
readyNodes.push_back(i);
|
||||||
|
|
||||||
while (!readyNodes.empty()) {
|
size_t readyIndex = 0;
|
||||||
size_t current = readyNodes.top();
|
while (readyIndex != readyNodes.size()) {
|
||||||
readyNodes.pop();
|
size_t current = readyNodes[readyIndex++];
|
||||||
timing.topologicalOrder.push_back(current);
|
timing.topologicalOrder.push_back(current);
|
||||||
for (auto [child, weight] : children[current]) {
|
for (auto [child, weight] : children[current]) {
|
||||||
(void) weight;
|
(void) weight;
|
||||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||||
incomingEdgeCount[child]--;
|
incomingEdgeCount[child]--;
|
||||||
if (incomingEdgeCount[child] == 0)
|
if (incomingEdgeCount[child] == 0)
|
||||||
readyNodes.push(child);
|
readyNodes.push_back(child);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -299,27 +158,10 @@ TimingInfo computeTiming(const VirtualGraph& graph) {
|
|||||||
return timing;
|
return timing;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
|
||||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
std::vector<size_t> selected(timing.aest.size());
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
std::iota(selected.begin(), selected.end(), 0);
|
||||||
(void) weight;
|
std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
|
||||||
adjacency[startIndex].push_back(endIndex);
|
|
||||||
adjacency[endIndex].push_back(startIndex);
|
|
||||||
}
|
|
||||||
for (auto& neighbours : adjacency) {
|
|
||||||
llvm::sort(neighbours);
|
|
||||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
|
||||||
}
|
|
||||||
return adjacency;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
|
||||||
std::vector<size_t> ranked(timing.aest.size());
|
|
||||||
std::iota(ranked.begin(), ranked.end(), 0);
|
|
||||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
|
||||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||||
if (lhsSlack != rhsSlack)
|
if (lhsSlack != rhsSlack)
|
||||||
@@ -327,85 +169,21 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const Timing
|
|||||||
if (timing.aest[lhs] != timing.aest[rhs])
|
if (timing.aest[lhs] != timing.aest[rhs])
|
||||||
return timing.aest[lhs] < timing.aest[rhs];
|
return timing.aest[lhs] < timing.aest[rhs];
|
||||||
return lhs < rhs;
|
return lhs < rhs;
|
||||||
};
|
});
|
||||||
|
selected.resize(std::min(windowSize, selected.size()));
|
||||||
windowSize = std::min(windowSize, ranked.size());
|
|
||||||
if (windowSize == 0)
|
|
||||||
return {};
|
|
||||||
if (windowSize == ranked.size()) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
return ranked;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
|
||||||
if (criticalPoolSize < ranked.size())
|
|
||||||
std::nth_element(
|
|
||||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
|
||||||
|
|
||||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
|
||||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
|
||||||
inCriticalPool[ranked[i]] = true;
|
|
||||||
|
|
||||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
|
||||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
|
||||||
std::vector<size_t> selected;
|
|
||||||
std::vector<char> inWindow(ranked.size(), false);
|
|
||||||
selected.reserve(windowSize);
|
|
||||||
|
|
||||||
struct FrontierEntry {
|
|
||||||
size_t node;
|
|
||||||
};
|
|
||||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
|
||||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
|
||||||
|
|
||||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
|
||||||
if (inWindow[node])
|
|
||||||
return;
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour] && eligible[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
};
|
|
||||||
|
|
||||||
addToWindow(seed, inCriticalPool);
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, inCriticalPool);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
std::vector<char> anyNode(ranked.size(), true);
|
|
||||||
for (size_t node : selected)
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, anyNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
for (size_t node : ranked) {
|
|
||||||
if (selected.size() == windowSize)
|
|
||||||
break;
|
|
||||||
if (!inWindow[node]) {
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::sort(selected, isHigherPriority);
|
|
||||||
return selected;
|
return selected;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
|
||||||
|
std::vector<size_t> signature;
|
||||||
|
for (size_t nodeIndex : selectedNodes) {
|
||||||
|
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||||
|
signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
|
||||||
|
}
|
||||||
|
std::sort(signature.begin(), signature.end());
|
||||||
|
return signature;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||||
std::vector<IndexedEdge> windowEdges;
|
std::vector<IndexedEdge> windowEdges;
|
||||||
windowEdges.reserve(graph.edges.size());
|
windowEdges.reserve(graph.edges.size());
|
||||||
@@ -419,71 +197,48 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
|
|||||||
return aggregateEdges(windowEdges);
|
return aggregateEdges(windowEdges);
|
||||||
}
|
}
|
||||||
|
|
||||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
WindowScheduleResult
|
||||||
|
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||||
std::vector<Weight> windowWeights;
|
std::vector<Weight> windowWeights;
|
||||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||||
std::vector<int64_t> windowNodeOrderKeys;
|
|
||||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||||
windowWeights.reserve(selectedNodes.size());
|
windowWeights.reserve(selectedNodes.size());
|
||||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
|
||||||
|
|
||||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphDCP windowGraph(
|
GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage);
|
||||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
if (coresCount.getValue() > 0)
|
||||||
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||||
windowGraph.setContext(context);
|
windowGraph.setContext(context);
|
||||||
windowGraph.runDcp();
|
windowGraph.runDcp();
|
||||||
|
|
||||||
WindowScheduleResult result;
|
WindowScheduleResult result;
|
||||||
result.cpuCount = windowGraph.cpuCount();
|
result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();
|
||||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||||
if (scheduledTasks.size() < 2)
|
if (scheduledTasks.size() < 2)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
result.mergedNodeCount += scheduledTasks.size();
|
|
||||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
|
||||||
std::vector<size_t> mergeGroup;
|
std::vector<size_t> mergeGroup;
|
||||||
mergeGroup.reserve(scheduledTasks.size());
|
mergeGroup.reserve(scheduledTasks.size());
|
||||||
for (const auto& task : scheduledTasks)
|
for (const auto& task : scheduledTasks)
|
||||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||||
|
std::sort(mergeGroup.begin(), mergeGroup.end());
|
||||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool coarsenGraph(const VirtualGraph& graph,
|
bool coarsenGraph(const VirtualGraph& graph,
|
||||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
VirtualGraph& coarsenedGraph,
|
VirtualGraph& coarsenedGraph) {
|
||||||
std::vector<size_t>& oldToNewNode) {
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
|
||||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
|
||||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
|
||||||
if (timing.valid)
|
|
||||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
|
||||||
topologicalRank[nodeIndex] = rank;
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
|
||||||
orderedMergeGroups.reserve(mergeGroups.size());
|
|
||||||
for (const auto& mergeGroup : mergeGroups) {
|
|
||||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
|
||||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
|
||||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
|
||||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
|
||||||
return lhs < rhs;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
|
||||||
if (mergeGroup.size() < 2)
|
if (mergeGroup.size() < 2)
|
||||||
continue;
|
continue;
|
||||||
for (size_t nodeIndex : mergeGroup) {
|
for (size_t nodeIndex : mergeGroup) {
|
||||||
@@ -492,21 +247,18 @@ bool coarsenGraph(const VirtualGraph& graph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
|
||||||
std::vector<size_t> newNodeRank;
|
std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
|
||||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
|
||||||
bool mergedAny = false;
|
bool mergedAny = false;
|
||||||
coarsenedGraph.nodes.clear();
|
coarsenedGraph.nodes.clear();
|
||||||
coarsenedGraph.edges.clear();
|
coarsenedGraph.edges.clear();
|
||||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||||
newNodeRank.reserve(graph.nodes.size());
|
|
||||||
|
|
||||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||||
if (mergeGroupIndex == -1) {
|
if (mergeGroupIndex == -1) {
|
||||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -517,7 +269,7 @@ bool coarsenGraph(const VirtualGraph& graph,
|
|||||||
}
|
}
|
||||||
|
|
||||||
VirtualNode mergedNode;
|
VirtualNode mergedNode;
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||||
memberNode.originalComputeIndices.end());
|
memberNode.originalComputeIndices.end());
|
||||||
@@ -528,9 +280,8 @@ bool coarsenGraph(const VirtualGraph& graph,
|
|||||||
|
|
||||||
mergedAny = true;
|
mergedAny = true;
|
||||||
newNodeIndex = coarsenedGraph.nodes.size();
|
newNodeIndex = coarsenedGraph.nodes.size();
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
|
||||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -544,96 +295,87 @@ bool coarsenGraph(const VirtualGraph& graph,
|
|||||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||||
if (newStart == newEnd)
|
if (newStart == newEnd)
|
||||||
continue;
|
continue;
|
||||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
|
||||||
continue;
|
|
||||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||||
}
|
}
|
||||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||||
|
|
||||||
return true;
|
return computeTiming(coarsenedGraph).valid;
|
||||||
}
|
}
|
||||||
|
|
||||||
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
|
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
||||||
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
|
VirtualGraph& coarsenedGraph) {
|
||||||
|
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
||||||
|
return true;
|
||||||
|
|
||||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
std::vector<size_t> orderedGroupIndices(mergeGroups.size());
|
||||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
|
||||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
|
||||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
return mergeGroups[lhs].size() > mergeGroups[rhs].size();
|
||||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
});
|
||||||
return windowSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
|
std::vector<std::vector<size_t>> acceptedMergeGroups;
|
||||||
DCPAnalysisResult result;
|
acceptedMergeGroups.reserve(mergeGroups.size());
|
||||||
|
for (size_t groupIndex : orderedGroupIndices) {
|
||||||
|
std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
|
||||||
|
candidateMergeGroups.push_back(mergeGroups[groupIndex]);
|
||||||
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
VirtualGraph candidateGraph;
|
||||||
std::vector<size_t> virtualNodeOrder;
|
if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
|
||||||
if (timing.valid) {
|
|
||||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
virtualNodeOrder.resize(graph.nodes.size());
|
|
||||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
|
|
||||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
|
||||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
|
||||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
|
||||||
originalComputeToCpu[originalIndex] = cpu;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
|
||||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
size_t cpu = originalComputeToCpu[originalIndex];
|
|
||||||
result.dominanceOrderCompute.push_back(computeInstance);
|
|
||||||
result.computeToCpuMap[computeInstance] = cpu;
|
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
|
||||||
}
|
|
||||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
|
||||||
result.isLastComputeOfCpu.insert(lastCompute);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
|
|
||||||
DCPAnalysisResult result;
|
|
||||||
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
|
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
|
||||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.empty())
|
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
for (const auto& task : scheduledTasks)
|
acceptedMergeGroups = std::move(candidateMergeGroups);
|
||||||
result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu;
|
coarsenedGraph = std::move(candidateGraph);
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
}
|
||||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
return !acceptedMergeGroups.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
VirtualGraph graph;
|
||||||
|
graph.nodes.resize(computeCount);
|
||||||
|
graph.edges = aggregateEdges(edges);
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
if (timing.valid)
|
||||||
|
return timing.topologicalOrder;
|
||||||
|
|
||||||
|
std::vector<size_t> fallbackOrder(computeCount);
|
||||||
|
std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
|
||||||
|
return fallbackOrder;
|
||||||
|
}
|
||||||
|
|
||||||
|
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
||||||
|
llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
|
llvm::ArrayRef<IndexedEdge> originalEdges) {
|
||||||
|
DCPAnalysisResult result;
|
||||||
|
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
||||||
|
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
||||||
|
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||||
|
originalToVirtualNode[originalIndex] = virtualNodeIndex;
|
||||||
|
|
||||||
|
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
|
||||||
|
result.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||||
|
for (size_t originalIndex : dominanceOrder) {
|
||||||
|
SpatCompute spatCompute = spatComputes[originalIndex];
|
||||||
|
size_t cpu = originalToVirtualNode[originalIndex];
|
||||||
|
result.dominanceOrderCompute.push_back(spatCompute);
|
||||||
|
result.computeToCpuMap[spatCompute] = cpu;
|
||||||
|
result.cpuToLastComputeMap[cpu] = spatCompute;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
|
result.isLastComputeOfCpu.insert(lastCompute);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult
|
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
|
||||||
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
llvm::ArrayRef<IndexedEdge> edges,
|
||||||
SmallVector<Weight> nodeWeights;
|
MLIRContext* context) {
|
||||||
SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
GraphDCP graphDCP(spatComputes, edges);
|
||||||
SmallVector<int64_t> nodeOrderKeys;
|
|
||||||
nodeWeights.reserve(computeInstances.size());
|
|
||||||
nodeCrossbarUsage.reserve(computeInstances.size());
|
|
||||||
nodeOrderKeys.reserve(computeInstances.size());
|
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
|
|
||||||
nodeWeights.push_back(getComputeInstanceWeight(instance));
|
|
||||||
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
|
|
||||||
nodeOrderKeys.push_back(static_cast<int64_t>(index));
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
if (coresCount.getValue() > 0)
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||||
graphDCP.setContext(context);
|
graphDCP.setContext(context);
|
||||||
graphDCP.runDcp();
|
graphDCP.runDcp();
|
||||||
return buildResultFromScheduledGraph(graphDCP, computeInstances);
|
return graphDCP.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -641,117 +383,64 @@ runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> e
|
|||||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
op = extract.getSource().getDefiningOp();
|
op = extract.getSource().getDefiningOp();
|
||||||
if (!op)
|
if (!op)
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
if (auto res = dyn_cast<SpatCompute>(op))
|
if (auto res = llvm::dyn_cast<SpatCompute>(op))
|
||||||
return res;
|
return res;
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult DCPAnalysis::run() {
|
DCPAnalysisResult DCPAnalysis::run() {
|
||||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
SmallVector<SpatCompute, 10> spatComputes;
|
||||||
SmallVector<IndexedEdge, 10> edges;
|
SmallVector<IndexedEdge, 10> edges;
|
||||||
|
for (auto& region : entryOp->getRegions())
|
||||||
|
for (SpatCompute spatCompute : region.getOps<SpatCompute>())
|
||||||
|
spatComputes.push_back(spatCompute);
|
||||||
|
|
||||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
instanceToIndex.reserve(computeInstances.size());
|
for (Value input : spatCompute.getInputs()) {
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
|
||||||
instanceToIndex[instance] = index;
|
auto producerIt = llvm::find(spatComputes, producerCompute);
|
||||||
|
assert(producerIt != spatComputes.end());
|
||||||
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
|
||||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
|
||||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
|
||||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
|
||||||
assert(producerIt != instanceToIndex.end());
|
|
||||||
auto indexStartEdge = producerIt->second;
|
|
||||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
|
||||||
static_cast<int64_t>(indexEndEdge),
|
|
||||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dcpCriticalWindowSize.getValue() == 0)
|
if (dcpCriticalWindowSize.getValue() == 0)
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
|
||||||
|
|
||||||
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
|
||||||
|
std::set<std::vector<size_t>> seenCriticalWindows;
|
||||||
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
|
if (!timing.valid)
|
||||||
|
break;
|
||||||
|
|
||||||
|
auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
|
||||||
|
if (selectedNodes.size() < 2)
|
||||||
|
break;
|
||||||
|
|
||||||
|
if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
|
||||||
|
break;
|
||||||
|
|
||||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
|
||||||
size_t iteration = 0;
|
|
||||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
|
||||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
|
||||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||||
if (windowSchedule.mergeGroups.empty()) {
|
if (windowSchedule.mergeGroups.empty())
|
||||||
if (oldNodeCount >= 200)
|
break;
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph coarsenedGraph;
|
VirtualGraph coarsenedGraph;
|
||||||
std::vector<size_t> oldToNewNode;
|
if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
|
||||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
break;
|
||||||
return false;
|
|
||||||
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount,
|
|
||||||
windowSchedule.mergeGroups.size(),
|
|
||||||
windowSchedule.mergedNodeCount,
|
|
||||||
windowSchedule.maxMergeGroupSize,
|
|
||||||
coarsenedGraph.nodes.size(),
|
|
||||||
oldNodeCount - coarsenedGraph.nodes.size());
|
|
||||||
virtualGraph = std::move(coarsenedGraph);
|
virtualGraph = std::move(coarsenedGraph);
|
||||||
return true;
|
if (windowSchedule.usedAllAvailableCpus)
|
||||||
};
|
|
||||||
|
|
||||||
while (virtualGraph.nodes.size() > 1) {
|
|
||||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
|
||||||
if (virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
|
|
||||||
iteration++;
|
|
||||||
TimingInfo timing = computeTiming(virtualGraph);
|
|
||||||
if (!timing.valid) {
|
|
||||||
if (virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<size_t> selectedNodes;
|
|
||||||
auto criticalWindow =
|
|
||||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
|
||||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
|
||||||
|
|
||||||
if (selectedNodes.size() < 2) {
|
|
||||||
if (virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
|
||||||
iteration,
|
|
||||||
virtualGraph.nodes.size(),
|
|
||||||
selectedNodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
|
||||||
continue;
|
|
||||||
if (virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -5,28 +5,15 @@
|
|||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
|
||||||
// spat.compute_batch.
|
|
||||||
struct ComputeInstance {
|
|
||||||
mlir::Operation* op = nullptr;
|
|
||||||
uint32_t laneStart = 0;
|
|
||||||
uint32_t laneCount = 1;
|
|
||||||
|
|
||||||
bool operator==(const ComputeInstance& other) const {
|
|
||||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
struct DCPAnalysisResult {
|
||||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -47,21 +34,3 @@ public:
|
|||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
namespace llvm {
|
|
||||||
template <>
|
|
||||||
struct DenseMapInfo<ComputeInstance> {
|
|
||||||
static ComputeInstance getEmptyKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static ComputeInstance getTombstoneKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static unsigned getHashValue(const ComputeInstance& v) {
|
|
||||||
return llvm::hash_combine(v.op, v.laneStart, v.laneCount);
|
|
||||||
}
|
|
||||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) {
|
|
||||||
return a == b;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace llvm
|
|
||||||
|
|||||||
@@ -38,14 +38,11 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
#include <cstdint>
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <queue>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
@@ -63,7 +60,6 @@ namespace {
|
|||||||
// Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set.
|
// Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set.
|
||||||
struct SelectTimers {
|
struct SelectTimers {
|
||||||
double findSlot = 0.0;
|
double findSlot = 0.0;
|
||||||
double dedup = 0.0;
|
|
||||||
double precheck = 0.0;
|
double precheck = 0.0;
|
||||||
double snapshotInsertUpdate = 0.0;
|
double snapshotInsertUpdate = 0.0;
|
||||||
double childSlot = 0.0;
|
double childSlot = 0.0;
|
||||||
@@ -74,19 +70,9 @@ struct SelectTimers {
|
|||||||
long tasksProcessed = 0;
|
long tasksProcessed = 0;
|
||||||
void dump(const char* label) const {
|
void dump(const char* label) const {
|
||||||
std::fprintf(stderr,
|
std::fprintf(stderr,
|
||||||
"[selectProfile:%s] tasks=%ld dedup=%.2fs findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs "
|
"[selectProfile:%s] tasks=%ld findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
|
||||||
"childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
|
label, tasksProcessed, findSlot, precheck, snapshotInsertUpdate, childSlot,
|
||||||
label,
|
rollbackRestore, iterations, passedPrecheck, passedDcpl);
|
||||||
tasksProcessed,
|
|
||||||
dedup,
|
|
||||||
findSlot,
|
|
||||||
precheck,
|
|
||||||
snapshotInsertUpdate,
|
|
||||||
childSlot,
|
|
||||||
rollbackRestore,
|
|
||||||
iterations,
|
|
||||||
passedPrecheck,
|
|
||||||
passedDcpl);
|
|
||||||
}
|
}
|
||||||
~SelectTimers() {
|
~SelectTimers() {
|
||||||
if (std::getenv("DCP_SELECT_PROFILE"))
|
if (std::getenv("DCP_SELECT_PROFILE"))
|
||||||
@@ -97,101 +83,6 @@ static SelectTimers gSelectTimers;
|
|||||||
} // namespace
|
} // namespace
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
uint64_t mixHash(uint64_t seed, uint64_t value) {
|
|
||||||
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
|
|
||||||
return seed;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t finishHash(uint64_t seed) {
|
|
||||||
seed ^= seed >> 33;
|
|
||||||
seed *= 0xff51afd7ed558ccdULL;
|
|
||||||
seed ^= seed >> 33;
|
|
||||||
seed *= 0xc4ceb9fe1a85ec53ULL;
|
|
||||||
seed ^= seed >> 33;
|
|
||||||
return seed;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t hashEdgeSignature(uint64_t neighborHash, Weight weight, uint64_t direction) {
|
|
||||||
uint64_t hash = mixHash(0x84222325cbf29ce4ULL, direction);
|
|
||||||
hash = mixHash(hash, neighborHash);
|
|
||||||
hash = mixHash(hash, static_cast<uint64_t>(weight));
|
|
||||||
return finishHash(hash);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CpuAestCache {
|
|
||||||
Time defaultAest = 0;
|
|
||||||
llvm::SmallDenseMap<CPU, Time, 8> colocatedParentAests;
|
|
||||||
|
|
||||||
Time get(CPU cpu) const {
|
|
||||||
auto it = colocatedParentAests.find(cpu);
|
|
||||||
if (it == colocatedParentAests.end())
|
|
||||||
return defaultAest;
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct CpuTimeMax {
|
|
||||||
CPU cpu = -1;
|
|
||||||
Time time = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
void updateCpuTimeMax(CpuTimeMax& first, CpuTimeMax& second, CPU cpu, Time time) {
|
|
||||||
if (first.cpu == cpu) {
|
|
||||||
first.time = std::max(first.time, time);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (second.cpu == cpu) {
|
|
||||||
second.time = std::max(second.time, time);
|
|
||||||
if (second.time > first.time)
|
|
||||||
std::swap(first, second);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (time >= first.time) {
|
|
||||||
second = first;
|
|
||||||
first = {cpu, time};
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (time > second.time)
|
|
||||||
second = {cpu, time};
|
|
||||||
}
|
|
||||||
|
|
||||||
CpuAestCache computeCpuAestCache(TaskDCP* task) {
|
|
||||||
CpuAestCache cache;
|
|
||||||
llvm::SmallDenseMap<CPU, Time, 8> transferAestByCpu;
|
|
||||||
llvm::SmallDenseMap<CPU, Time, 8> localAestByCpu;
|
|
||||||
Time unscheduledTransferAest = 0;
|
|
||||||
|
|
||||||
for (const Edge& parentEdge : task->parents) {
|
|
||||||
Time parentFinish = addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight());
|
|
||||||
Time transferAest = addOrMax(parentFinish, getTransferCost(parentEdge.first, task));
|
|
||||||
if (std::optional<CPU> parentCpu = parentEdge.first->getCpu()) {
|
|
||||||
Time& cpuTransferAest = transferAestByCpu[*parentCpu];
|
|
||||||
cpuTransferAest = std::max(cpuTransferAest, transferAest);
|
|
||||||
Time& cpuLocalAest = localAestByCpu[*parentCpu];
|
|
||||||
cpuLocalAest = std::max(cpuLocalAest, parentFinish);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
unscheduledTransferAest = std::max(unscheduledTransferAest, transferAest);
|
|
||||||
}
|
|
||||||
|
|
||||||
CpuTimeMax firstOther {-1, unscheduledTransferAest};
|
|
||||||
CpuTimeMax secondOther {-1, 0};
|
|
||||||
for (const auto& entry : transferAestByCpu)
|
|
||||||
updateCpuTimeMax(firstOther, secondOther, entry.first, entry.second);
|
|
||||||
|
|
||||||
cache.defaultAest = firstOther.time;
|
|
||||||
for (const auto& entry : localAestByCpu) {
|
|
||||||
CPU cpu = entry.first;
|
|
||||||
Time bestNonLocalParentAest = firstOther.cpu == cpu ? secondOther.time : firstOther.time;
|
|
||||||
cache.colocatedParentAests[cpu] = std::max(bestNonLocalParentAest, entry.second);
|
|
||||||
}
|
|
||||||
return cache;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Edge manipulation
|
// Edge manipulation
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -265,49 +156,6 @@ std::vector<TaskDCP*> GraphDCP::getRoots() {
|
|||||||
return tmp;
|
return tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphDCP::initTaskStructureHashes() {
|
|
||||||
taskStructureHashes.resize(nodes.size());
|
|
||||||
for (auto [index, task] : llvm::enumerate(nodes)) {
|
|
||||||
uint64_t hash = mixHash(0x7442b1129fd01363ULL, static_cast<uint64_t>(task.getWeight()));
|
|
||||||
hash = mixHash(hash, static_cast<uint64_t>(task.getCrossbarUsage()));
|
|
||||||
taskStructureHashes[index] = finishHash(hash);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint64_t> nextHashes(nodes.size());
|
|
||||||
std::vector<uint64_t> edgeHashes;
|
|
||||||
for (int iteration = 0; iteration < 4; ++iteration) {
|
|
||||||
for (auto [index, task] : llvm::enumerate(nodes)) {
|
|
||||||
uint64_t hash = mixHash(0x464dcab27ac82291ULL, taskStructureHashes[index]);
|
|
||||||
edgeHashes.clear();
|
|
||||||
edgeHashes.reserve(task.parents.size() + task.children.size());
|
|
||||||
for (const Edge& parent : task.parents)
|
|
||||||
if (!parent.isScheduling)
|
|
||||||
edgeHashes.push_back(
|
|
||||||
hashEdgeSignature(taskStructureHashes[getNodeIndex(parent.first)], parent.second, /*direction=*/0));
|
|
||||||
for (const Edge& child : task.children)
|
|
||||||
if (!child.isScheduling)
|
|
||||||
edgeHashes.push_back(
|
|
||||||
hashEdgeSignature(taskStructureHashes[getNodeIndex(child.first)], child.second, /*direction=*/1));
|
|
||||||
llvm::sort(edgeHashes);
|
|
||||||
hash = mixHash(hash, static_cast<uint64_t>(edgeHashes.size()));
|
|
||||||
for (uint64_t edgeHash : edgeHashes)
|
|
||||||
hash = mixHash(hash, edgeHash);
|
|
||||||
nextHashes[index] = finishHash(hash);
|
|
||||||
}
|
|
||||||
taskStructureHashes.swap(nextHashes);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compact dedup key for CPU `c` vs `candidate`: mixes candidateAest, crossbar
|
|
||||||
// usage, and the incremental cpu structure hash. No heap allocation.
|
|
||||||
uint64_t GraphDCP::computeCpuCandidateKey(Time candidateAest, CPU cpu) {
|
|
||||||
uint64_t hash = mixHash(0xd6e8feb86659fd93ULL, static_cast<uint64_t>(candidateAest));
|
|
||||||
hash = mixHash(hash, static_cast<uint64_t>(getCpuCrossbarUsage(cpu)));
|
|
||||||
auto it = cpuStructureHashes.find(cpu);
|
|
||||||
hash = mixHash(hash, it != cpuStructureHashes.end() ? it->second : 0ULL);
|
|
||||||
return finishHash(hash);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
|
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
|
||||||
// neighbouring tasks and keeping the global topological order consistent.
|
// neighbouring tasks and keeping the global topological order consistent.
|
||||||
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
||||||
@@ -316,7 +164,6 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
|
|||||||
task->setCpu(cpu);
|
task->setCpu(cpu);
|
||||||
task->setWeight(scheduledWeight);
|
task->setWeight(scheduledWeight);
|
||||||
reserveTaskCrossbars(cpu, task);
|
reserveTaskCrossbars(cpu, task);
|
||||||
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
|
|
||||||
auto& tasksInCpu = getOrCreateCpuTasks(cpu);
|
auto& tasksInCpu = getOrCreateCpuTasks(cpu);
|
||||||
unsigned int numCpuTasks = tasksInCpu.size();
|
unsigned int numCpuTasks = tasksInCpu.size();
|
||||||
assert(position <= numCpuTasks && "Inserting in a not valid position");
|
assert(position <= numCpuTasks && "Inserting in a not valid position");
|
||||||
@@ -354,7 +201,6 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
|
|||||||
|
|
||||||
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
|
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
|
||||||
releaseTaskCrossbars(cpu, task);
|
releaseTaskCrossbars(cpu, task);
|
||||||
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
|
|
||||||
task->resetCpu();
|
task->resetCpu();
|
||||||
task->resetWeight();
|
task->resetWeight();
|
||||||
auto& scheduledTasks = getOrCreateCpuTasks(cpu);
|
auto& scheduledTasks = getOrCreateCpuTasks(cpu);
|
||||||
@@ -425,21 +271,6 @@ bool GraphDCP::wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const
|
|||||||
return nextUsage >= getCpuCrossbarCapacity();
|
return nextUsage >= getCpuCrossbarCapacity();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t GraphDCP::crossbarsUsed() const {
|
|
||||||
CrossbarUsage crossbarEdge = static_cast<CrossbarUsage>(onnx_mlir::crossbarSize.getValue());
|
|
||||||
CrossbarUsage crossbarArea = crossbarEdge * crossbarEdge;
|
|
||||||
if (crossbarArea == 0)
|
|
||||||
return 0;
|
|
||||||
CrossbarUsage totalArea = 0;
|
|
||||||
for (const auto& [cpu, usage] : cpuCrossbarUsage)
|
|
||||||
totalArea = checkedAdd(totalArea, usage);
|
|
||||||
return static_cast<size_t>(totalArea / crossbarArea);
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GraphDCP::crossbarsAvailable() const {
|
|
||||||
return static_cast<size_t>(lastCpu) * onnx_mlir::crossbarCountInCore.getValue();
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AEST / ALST computation
|
// AEST / ALST computation
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -625,9 +456,9 @@ void GraphDCP::updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<T
|
|||||||
for (TaskDCP* descendant : descendantsTopoOrder)
|
for (TaskDCP* descendant : descendantsTopoOrder)
|
||||||
recomputeAest(descendant);
|
recomputeAest(descendant);
|
||||||
|
|
||||||
const bool oldMaxInvalidated =
|
const bool oldMaxInvalidated = maxCompletionTask != nullptr
|
||||||
maxCompletionTask != nullptr
|
&& (maxCompletionTask == task
|
||||||
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
||||||
if (oldMaxInvalidated) {
|
if (oldMaxInvalidated) {
|
||||||
// The pre-update max came from a modified task; its completion has moved
|
// The pre-update max came from a modified task; its completion has moved
|
||||||
// upward, so modifiedMaxCompletion is an upper bound covering it. The
|
// upward, so modifiedMaxCompletion is an upper bound covering it. The
|
||||||
@@ -692,9 +523,9 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
|
|||||||
if (!process(descendant))
|
if (!process(descendant))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
const bool oldMaxInvalidated =
|
const bool oldMaxInvalidated = maxCompletionTask != nullptr
|
||||||
maxCompletionTask != nullptr
|
&& (maxCompletionTask == task
|
||||||
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
||||||
if (oldMaxInvalidated) {
|
if (oldMaxInvalidated) {
|
||||||
dcpl = modifiedMaxCompletion;
|
dcpl = modifiedMaxCompletion;
|
||||||
maxCompletion = modifiedMaxCompletion;
|
maxCompletion = modifiedMaxCompletion;
|
||||||
@@ -715,109 +546,6 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Incrementally refreshes ALST after `task` was placed. The set of nodes whose
|
|
||||||
// ALST is structurally affected by the insertion is exactly
|
|
||||||
// `relations.ancestors ∪ {task}`: the task's outgoing transfer costs to
|
|
||||||
// same-CPU real children become 0, and new scheduling edges create parent
|
|
||||||
// relationships between `task` and its same-CPU neighbors. Every other node
|
|
||||||
// keeps its relative distance to the sink boundary and only absorbs the
|
|
||||||
// signed DCPL delta captured between `oldDcpl` and the now-updated `dcpl`.
|
|
||||||
void GraphDCP::updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl) {
|
|
||||||
Time newDcpl = getDcpl();
|
|
||||||
// If the AEST update saturated dcpl (e.g. rescue placement on a
|
|
||||||
// crossbar-exhausted CPU sets task weight to UINT64_MAX), the shift delta
|
|
||||||
// would be meaningless. Fall back to a full recompute for this step only.
|
|
||||||
if (newDcpl == std::numeric_limits<Time>::max()) {
|
|
||||||
initAlst();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (newDcpl != oldDcpl) {
|
|
||||||
const bool increased = newDcpl > oldDcpl;
|
|
||||||
const Time delta = increased ? (newDcpl - oldDcpl) : (oldDcpl - newDcpl);
|
|
||||||
for (TaskDCP& node : topologicalOrder) {
|
|
||||||
if (&node == task || relations.ancestors.contains(&node))
|
|
||||||
continue;
|
|
||||||
Time alst = node.getAlst();
|
|
||||||
node.setAlst(increased ? addOrMax(alst, delta) : subtractOrZero(alst, delta));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto recomputeAlst = [&](TaskDCP* node) {
|
|
||||||
Time minAlst = std::numeric_limits<Time>::max();
|
|
||||||
if (!node->hasChildren())
|
|
||||||
minAlst = subtractOrZero(newDcpl, node->getWeight());
|
|
||||||
for (const Edge& childEdge : node->children)
|
|
||||||
minAlst = std::min(minAlst,
|
|
||||||
subtractOrZero(childEdge.first->getAlst(),
|
|
||||||
addOrMax(node->getWeight(), getTransferCost(node, childEdge.first))));
|
|
||||||
node->setAlst(minAlst);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Walk the backward cone with a pending-children counter so that every
|
|
||||||
// ancestor is recomputed only after all of its affected children have
|
|
||||||
// been refreshed. This is resilient to staleness in the global
|
|
||||||
// `topologicalOrder` relative to freshly added scheduling edges.
|
|
||||||
llvm::DenseSet<TaskDCP*> affected = relations.ancestors;
|
|
||||||
affected.insert(task);
|
|
||||||
llvm::DenseMap<TaskDCP*, int> pendingAffectedChildren;
|
|
||||||
pendingAffectedChildren.reserve(affected.size());
|
|
||||||
std::vector<TaskDCP*> worklist;
|
|
||||||
worklist.reserve(affected.size());
|
|
||||||
for (TaskDCP* node : affected) {
|
|
||||||
int count = 0;
|
|
||||||
for (const Edge& childEdge : node->children)
|
|
||||||
if (affected.contains(childEdge.first))
|
|
||||||
count++;
|
|
||||||
pendingAffectedChildren[node] = count;
|
|
||||||
if (count == 0)
|
|
||||||
worklist.push_back(node);
|
|
||||||
}
|
|
||||||
while (!worklist.empty()) {
|
|
||||||
TaskDCP* node = worklist.back();
|
|
||||||
worklist.pop_back();
|
|
||||||
recomputeAlst(node);
|
|
||||||
for (const Edge& parentEdge : node->parents) {
|
|
||||||
if (!affected.contains(parentEdge.first))
|
|
||||||
continue;
|
|
||||||
auto it = pendingAffectedChildren.find(parentEdge.first);
|
|
||||||
assert(it != pendingAffectedChildren.end());
|
|
||||||
if (--it->second == 0)
|
|
||||||
worklist.push_back(parentEdge.first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Opt-in consistency check: verifies the incremental ALST result against a
|
|
||||||
// full initAlst() recomputation. Very expensive (O(V+E) per placement) - only
|
|
||||||
// enable when investigating suspected drift.
|
|
||||||
#ifdef DCP_DEBUG_CHECK_ALST
|
|
||||||
std::vector<Time> afterIncremental(nodes.size());
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i)
|
|
||||||
afterIncremental[i] = nodes[i].getAlst();
|
|
||||||
initAlst();
|
|
||||||
bool mismatched = false;
|
|
||||||
for (size_t i = 0; i < nodes.size(); ++i) {
|
|
||||||
if (afterIncremental[i] != nodes[i].getAlst()) {
|
|
||||||
if (!mismatched) {
|
|
||||||
llvm::errs() << "[alst-mismatch] placed=" << getNodeIndex(task) << " oldDcpl=" << oldDcpl
|
|
||||||
<< " newDcpl=" << newDcpl << " ancestors={";
|
|
||||||
for (TaskDCP* a : relations.ancestors)
|
|
||||||
llvm::errs() << getNodeIndex(a) << ",";
|
|
||||||
llvm::errs() << "}\n";
|
|
||||||
mismatched = true;
|
|
||||||
}
|
|
||||||
llvm::errs() << " node=" << i << " incremental=" << afterIncremental[i] << " full=" << nodes[i].getAlst()
|
|
||||||
<< " weight=" << nodes[i].getWeight()
|
|
||||||
<< " cpu=" << (nodes[i].isScheduled() ? (int) *nodes[i].getCpu() : -1) << " children=[";
|
|
||||||
for (const Edge& e : nodes[i].children)
|
|
||||||
llvm::errs() << getNodeIndex(e.first) << (e.isScheduling ? "s" : "")
|
|
||||||
<< "(tc=" << getTransferCost(&nodes[i], e.first) << ",alst=" << e.first->getAlst() << "),";
|
|
||||||
llvm::errs() << "]\n";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes a localised ALST: only ancestors of the candidate (plus the
|
// Computes a localised ALST: only ancestors of the candidate (plus the
|
||||||
// candidate itself) get recomputed, every other task keeps its current ALST.
|
// candidate itself) get recomputed, every other task keeps its current ALST.
|
||||||
// Processes nodes in reverse dependency order using a pending-children
|
// Processes nodes in reverse dependency order using a pending-children
|
||||||
@@ -1177,6 +905,32 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
|
|||||||
// Candidate selection and processor assignment
|
// Candidate selection and processor assignment
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Lowest slack wins; earliest AEST breaks ties. Critical-path tasks (zero
|
||||||
|
// slack) naturally float to the front.
|
||||||
|
TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
|
||||||
|
auto findBestNode = [](auto lft, auto rgt) {
|
||||||
|
Time leftSlack = slackOrZero((*lft)->getAest(), (*lft)->getAlst());
|
||||||
|
Time rightSlack = slackOrZero((*rgt)->getAest(), (*rgt)->getAlst());
|
||||||
|
if (leftSlack < rightSlack)
|
||||||
|
return lft;
|
||||||
|
if (rightSlack < leftSlack)
|
||||||
|
return rgt;
|
||||||
|
if ((*lft)->getAest() < (*rgt)->getAest())
|
||||||
|
return lft;
|
||||||
|
return rgt;
|
||||||
|
};
|
||||||
|
|
||||||
|
assert(!readyNodes.empty() && "expected at least one ready node");
|
||||||
|
auto validNode = readyNodes.begin();
|
||||||
|
auto bestNode = validNode;
|
||||||
|
|
||||||
|
while (validNode != readyNodes.end()) {
|
||||||
|
bestNode = findBestNode(validNode, bestNode);
|
||||||
|
std::advance(validNode, 1);
|
||||||
|
}
|
||||||
|
return *bestNode;
|
||||||
|
}
|
||||||
|
|
||||||
// Picks the best CPU + slot for `candidate`:
|
// Picks the best CPU + slot for `candidate`:
|
||||||
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
|
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
|
||||||
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
|
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
|
||||||
@@ -1185,7 +939,7 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
|
|||||||
// evaluate a slot for the smallest-slack child, then roll back.
|
// evaluate a slot for the smallest-slack child, then roll back.
|
||||||
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
|
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
|
||||||
// otherwise pick the CPU that leads to the smallest DCPL increase.
|
// otherwise pick the CPU that leads to the smallest DCPL increase.
|
||||||
GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
||||||
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
|
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
|
||||||
relations.descendantsTopoOrder.reserve(relations.descendants.size());
|
relations.descendantsTopoOrder.reserve(relations.descendants.size());
|
||||||
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
|
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
|
||||||
@@ -1205,43 +959,22 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
|
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
|
||||||
const bool candidateHasCrossbar = candidateFootprint != 0;
|
const bool candidateHasCrossbar = candidateFootprint != 0;
|
||||||
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
|
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
|
||||||
DCP_DEBUG_IF(auto dedupStart = std::chrono::steady_clock::now();)
|
|
||||||
CpuAestCache cpuAests = computeCpuAestCache(candidate);
|
|
||||||
DCP_DEBUG_IF(const bool checkCpuAestCache = std::getenv("DCP_CHECK_CPU_AEST_CACHE") != nullptr;)
|
|
||||||
llvm::SmallDenseSet<uint64_t, 32> seenProcessorKeys;
|
|
||||||
seenProcessorKeys.reserve(static_cast<size_t>(topCpu + 1));
|
|
||||||
for (CPU c = 0; c <= topCpu; c++) {
|
for (CPU c = 0; c <= topCpu; c++) {
|
||||||
if (candidateHasCrossbar && c != getLastCpu()) {
|
if (candidateHasCrossbar && c != getLastCpu()) {
|
||||||
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
|
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
|
||||||
if (nextUsage >= cpuCapacity)
|
if (nextUsage >= cpuCapacity)
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Time candidateAest = cpuAests.get(c);
|
|
||||||
DCP_DEBUG_IF(if (checkCpuAestCache) {
|
|
||||||
Time recomputedAest = computeAestOnCpu(candidate, c);
|
|
||||||
if (candidateAest != recomputedAest) {
|
|
||||||
std::fprintf(stderr,
|
|
||||||
"[DCP_CHECK_CPU_AEST_CACHE] mismatch candidate=%zu cpu=%d cached=%llu recomputed=%llu\n",
|
|
||||||
getNodeIndex(candidate),
|
|
||||||
c,
|
|
||||||
static_cast<unsigned long long>(candidateAest),
|
|
||||||
static_cast<unsigned long long>(recomputedAest));
|
|
||||||
llvm::report_fatal_error("DCP CPU AEST cache mismatch");
|
|
||||||
}
|
|
||||||
})
|
|
||||||
if (!seenProcessorKeys.insert(computeCpuCandidateKey(candidateAest, c)).second)
|
|
||||||
continue;
|
|
||||||
processors.push_back(c);
|
processors.push_back(c);
|
||||||
}
|
}
|
||||||
DCP_DEBUG_IF(gSelectTimers.dedup +=
|
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - dedupStart).count();)
|
|
||||||
if (processors.empty()) {
|
if (processors.empty()) {
|
||||||
// processors.empty() implies !canCreateNewCpu: a fresh CPU always passes
|
CPU bestCpu = canCreateNewCpu ? getLastCpu() : 0;
|
||||||
// the crossbar filter and would have been added. Reaching here means every
|
FindSlot bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
|
||||||
// existing CPU is crossbar-exhausted and the task requires crossbar
|
if (canCreateNewCpu)
|
||||||
// capacity — the placement is impossible.
|
incrementLastCpu();
|
||||||
llvm::report_fatal_error("DCP scheduler: crossbar capacity exhausted on all CPUs; "
|
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
|
||||||
"cannot schedule task that requires crossbar allocation");
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 1: parallel findSlot sweep (read-only over graph state).
|
// Phase 1: parallel findSlot sweep (read-only over graph state).
|
||||||
@@ -1267,20 +1000,21 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
for (size_t i = 0; i < processors.size(); ++i)
|
for (size_t i = 0; i < processors.size(); ++i)
|
||||||
sweep(i);
|
sweep(i);
|
||||||
DCP_DEBUG_IF(gSelectTimers.findSlot +=
|
DCP_DEBUG_IF(gSelectTimers.findSlot +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
|
||||||
|
|
||||||
#ifdef DCP_DEBUG_ENABLED
|
#ifdef DCP_DEBUG_ENABLED
|
||||||
{
|
{
|
||||||
static bool reported = false;
|
static bool reported = false;
|
||||||
if (!reported) {
|
if (!reported) {
|
||||||
reported = true;
|
reported = true;
|
||||||
std::fprintf(
|
std::fprintf(stderr,
|
||||||
stderr,
|
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
|
||||||
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
|
(void*) context,
|
||||||
(void*) context,
|
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
|
||||||
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
|
processors.size(),
|
||||||
processors.size(),
|
context != nullptr && context->isMultithreadingEnabled()
|
||||||
context != nullptr && context->isMultithreadingEnabled() ? context->getThreadPool().getMaxConcurrency() : 0u);
|
? context->getThreadPool().getMaxConcurrency()
|
||||||
|
: 0u);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -1321,10 +1055,9 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
|
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
|
||||||
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
|
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
|
||||||
Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
|
Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
|
||||||
bool skip =
|
bool skip = (!emptyCpu && candidateCompletion > currentDcpl)
|
||||||
(!emptyCpu && candidateCompletion > currentDcpl) || addOrMax(slot.aest, candidateCompletion) >= bestComposite;
|
|| addOrMax(slot.aest, candidateCompletion) >= bestComposite;
|
||||||
DCP_DEBUG_IF(gSelectTimers.precheck +=
|
DCP_DEBUG_IF(gSelectTimers.precheck += std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
|
|
||||||
if (skip)
|
if (skip)
|
||||||
continue;
|
continue;
|
||||||
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
|
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
|
||||||
@@ -1340,8 +1073,8 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
scheduleSnapshot = dcp_graph::captureLocalScheduleState(
|
scheduleSnapshot = dcp_graph::captureLocalScheduleState(
|
||||||
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
||||||
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
|
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
|
||||||
bool withinBudget =
|
bool withinBudget = tryUpdateAestWithinBudget(
|
||||||
tryUpdateAestWithinBudget(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
|
candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
|
||||||
if (!withinBudget) {
|
if (!withinBudget) {
|
||||||
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
|
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
|
||||||
taskInsertion.rollBack();
|
taskInsertion.rollBack();
|
||||||
@@ -1354,7 +1087,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
DCP_DEBUG_IF(gSelectTimers.snapshotInsertUpdate +=
|
DCP_DEBUG_IF(gSelectTimers.snapshotInsertUpdate +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
|
||||||
DCP_DEBUG_IF(++gSelectTimers.passedDcpl;)
|
DCP_DEBUG_IF(++gSelectTimers.passedDcpl;)
|
||||||
|
|
||||||
// Pick the tightest unscheduled child (smallest slack) and measure what
|
// Pick the tightest unscheduled child (smallest slack) and measure what
|
||||||
@@ -1402,7 +1135,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
dcp_graph::restoreLocalScheduleState(
|
dcp_graph::restoreLocalScheduleState(
|
||||||
scheduleSnapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
scheduleSnapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
||||||
DCP_DEBUG_IF(gSelectTimers.rollbackRestore +=
|
DCP_DEBUG_IF(gSelectTimers.rollbackRestore +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1417,9 +1150,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
else {
|
else {
|
||||||
Time bestDcpl = std::numeric_limits<Time>::max();
|
Time bestDcpl = std::numeric_limits<Time>::max();
|
||||||
Time currentDcpl = getDcpl();
|
Time currentDcpl = getDcpl();
|
||||||
for (CPU c : processors) {
|
for (CPU c = 0; c < getLastCpu(); c++) {
|
||||||
if (c == getLastCpu())
|
|
||||||
continue;
|
|
||||||
auto slot = findSlot(candidate, c, false, relations);
|
auto slot = findSlot(candidate, c, false, relations);
|
||||||
if (slot.aest == std::numeric_limits<Time>::max())
|
if (slot.aest == std::numeric_limits<Time>::max())
|
||||||
slot = findSlot(candidate, c, true, relations);
|
slot = findSlot(candidate, c, true, relations);
|
||||||
@@ -1428,7 +1159,8 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
// Cheap lower bound: post-insertion DCPL is at least max(currentDcpl,
|
// Cheap lower bound: post-insertion DCPL is at least max(currentDcpl,
|
||||||
// candidate completion on this slot). Skip CPUs already worse than
|
// candidate completion on this slot). Skip CPUs already worse than
|
||||||
// the best seen.
|
// the best seen.
|
||||||
Time lowerBound = std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
|
Time lowerBound =
|
||||||
|
std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
|
||||||
if (lowerBound >= bestDcpl)
|
if (lowerBound >= bestDcpl)
|
||||||
continue;
|
continue;
|
||||||
auto snapshot = dcp_graph::captureLocalScheduleState(
|
auto snapshot = dcp_graph::captureLocalScheduleState(
|
||||||
@@ -1437,37 +1169,23 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder));
|
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder));
|
||||||
Time candidateDcpl = getDcpl();
|
Time candidateDcpl = getDcpl();
|
||||||
taskInsertion.rollBack();
|
taskInsertion.rollBack();
|
||||||
dcp_graph::restoreLocalScheduleState(snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
dcp_graph::restoreLocalScheduleState(
|
||||||
|
snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
||||||
if (candidateDcpl < bestDcpl) {
|
if (candidateDcpl < bestDcpl) {
|
||||||
bestDcpl = candidateDcpl;
|
bestDcpl = candidateDcpl;
|
||||||
bestCpu = c;
|
bestCpu = c;
|
||||||
bestSlot = slot;
|
bestSlot = slot;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (bestCpu == -1)
|
if (bestCpu == -1) {
|
||||||
llvm::report_fatal_error("DCP scheduler: no valid slot found for task on any eligible CPU — "
|
bestCpu = 0;
|
||||||
"all slots are blocked by already-placed descendants");
|
bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount)
|
if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount)
|
||||||
incrementLastCpu();
|
incrementLastCpu();
|
||||||
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
|
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
|
||||||
|
|
||||||
// Incremental AEST/ALST refresh replacing the full initAest/initAlst that
|
|
||||||
// used to run after every placement. Post-insertion relations pick up any
|
|
||||||
// new scheduling-edge ancestors/descendants introduced by the insertion.
|
|
||||||
Time oldDcpl = getDcpl();
|
|
||||||
CandidateRelations postRelations = dcp_graph::computeCandidateRelations(candidate);
|
|
||||||
llvm::SmallVector<TaskDCP*, 32> postDescendantsTopoOrder;
|
|
||||||
postDescendantsTopoOrder.reserve(postRelations.descendants.size());
|
|
||||||
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
|
|
||||||
TaskDCP* current = &*it;
|
|
||||||
if (current != candidate && postRelations.descendants.contains(current))
|
|
||||||
postDescendantsTopoOrder.push_back(current);
|
|
||||||
}
|
|
||||||
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(postDescendantsTopoOrder));
|
|
||||||
updateAlstFromScheduledTask(candidate, postRelations, oldDcpl);
|
|
||||||
return postRelations;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -1476,102 +1194,61 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
|
|||||||
|
|
||||||
void GraphDCP::runDcp() {
|
void GraphDCP::runDcp() {
|
||||||
initTopological();
|
initTopological();
|
||||||
initTaskStructureHashes();
|
|
||||||
initAest();
|
initAest();
|
||||||
initAlst();
|
initAlst();
|
||||||
dumpDot();
|
dumpDot();
|
||||||
|
|
||||||
dcp_graph::DcpProgressLogger progressLogger(nodes.size());
|
dcp_graph::DcpProgressLogger progressLogger(nodes.size());
|
||||||
llvm::DenseMap<TaskDCP*, int> unscheduledParents;
|
llvm::DenseMap<TaskDCP*, int> unscheduledParents;
|
||||||
|
std::vector<TaskDCP*> readyNodes;
|
||||||
// Min-heap over ready tasks: tightest slack first, earliest AEST as tiebreak.
|
readyNodes.reserve(nodes.size());
|
||||||
// Lazy deletion: when AEST/ALST change after a placement, fresh entries are
|
|
||||||
// pushed for the affected tasks. Stale ones are detected on pop by comparing
|
|
||||||
// stored vs current (slack, aest) and re-pushed with the current values.
|
|
||||||
struct ReadyEntry {
|
|
||||||
Time slack;
|
|
||||||
Time aest;
|
|
||||||
int64_t orderKey;
|
|
||||||
TaskDCP* task;
|
|
||||||
bool operator>(const ReadyEntry& other) const {
|
|
||||||
if (slack != other.slack)
|
|
||||||
return slack > other.slack;
|
|
||||||
if (aest != other.aest)
|
|
||||||
return aest > other.aest;
|
|
||||||
return orderKey > other.orderKey;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>> readyQueue;
|
|
||||||
size_t readyCount = 0;
|
|
||||||
|
|
||||||
auto pushReady = [&](TaskDCP* node) {
|
|
||||||
readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node->Id(), node});
|
|
||||||
};
|
|
||||||
|
|
||||||
for (auto& node : nodes) {
|
for (auto& node : nodes) {
|
||||||
int dependencyParents = dcp_graph::countDependencyParents(&node);
|
int dependencyParents = dcp_graph::countDependencyParents(&node);
|
||||||
unscheduledParents[&node] = dependencyParents;
|
unscheduledParents[&node] = dependencyParents;
|
||||||
if (dependencyParents == 0) {
|
if (dependencyParents == 0)
|
||||||
pushReady(&node);
|
readyNodes.push_back(&node);
|
||||||
++readyCount;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
size_t xbarsCapacity = static_cast<size_t>(maxCpuCount) * onnx_mlir::crossbarCountInCore.getValue();
|
progressLogger.printStart(readyNodes.size());
|
||||||
progressLogger.printStart(readyCount, maxCpuCount, xbarsCapacity);
|
|
||||||
|
|
||||||
while (readyCount > 0) {
|
while (!readyNodes.empty()) {
|
||||||
// Pop with lazy deletion: skip stale entries and re-push with current values.
|
DCP_DEBUG_IF(auto findStart = std::chrono::steady_clock::now();)
|
||||||
TaskDCP* candidate = nullptr;
|
TaskDCP* candidate = findCandidate(readyNodes);
|
||||||
while (!readyQueue.empty()) {
|
DCP_DEBUG_IF(progressLogger.recordFindDuration(
|
||||||
auto entry = readyQueue.top();
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - findStart).count());)
|
||||||
readyQueue.pop();
|
fastRemove(readyNodes, candidate);
|
||||||
Time curSlack = slackOrZero(entry.task->getAest(), entry.task->getAlst());
|
|
||||||
Time curAest = entry.task->getAest();
|
|
||||||
if (entry.slack == curSlack && entry.aest == curAest) {
|
|
||||||
candidate = entry.task;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
readyQueue.push({curSlack, curAest, entry.orderKey, entry.task});
|
|
||||||
}
|
|
||||||
assert(candidate != nullptr && "readyCount > 0 but heap exhausted");
|
|
||||||
--readyCount;
|
|
||||||
|
|
||||||
DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();)
|
DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();)
|
||||||
CandidateRelations postRelations = selectProcessor(candidate, candidate->isCriticalPath());
|
selectProcessor(candidate, candidate->isCriticalPath());
|
||||||
DCP_DEBUG_IF(
|
DCP_DEBUG_IF(
|
||||||
double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count();
|
double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count();
|
||||||
progressLogger.recordSelectDuration(selectSeconds);
|
progressLogger.recordSelectDuration(selectSeconds);
|
||||||
progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyCount, getLastCpu());)
|
progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyNodes.size(), getLastCpu());
|
||||||
|
)
|
||||||
// Proactively refresh the heap priority for ready nodes whose AEST or ALST
|
|
||||||
// changed: ancestors had their ALST individually recomputed; descendants had
|
|
||||||
// their AEST bumped. Both may now sort differently than their stale entries.
|
|
||||||
for (TaskDCP* node : postRelations.ancestors)
|
|
||||||
if (!node->isScheduled() && unscheduledParents[node] == 0)
|
|
||||||
pushReady(node);
|
|
||||||
for (TaskDCP* node : postRelations.descendants)
|
|
||||||
if (!node->isScheduled() && unscheduledParents[node] == 0)
|
|
||||||
pushReady(node);
|
|
||||||
|
|
||||||
|
DCP_DEBUG_IF(auto updateStart = std::chrono::steady_clock::now();)
|
||||||
|
initAest();
|
||||||
|
initAlst();
|
||||||
|
DCP_DEBUG_IF(progressLogger.recordUpdateDuration(
|
||||||
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - updateStart).count());)
|
||||||
progressLogger.advanceCompleted();
|
progressLogger.advanceCompleted();
|
||||||
progressLogger.printProgress(readyCount, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), false);
|
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "recompute", false);
|
||||||
|
|
||||||
for (const auto& childEdge : candidate->children) {
|
for (const auto& childEdge : candidate->children) {
|
||||||
if (childEdge.isScheduling || childEdge.first->isScheduled())
|
if (childEdge.isScheduling || childEdge.first->isScheduled())
|
||||||
continue;
|
continue;
|
||||||
int& dependencyParents = unscheduledParents[childEdge.first];
|
int& dependencyParents = unscheduledParents[childEdge.first];
|
||||||
assert(dependencyParents > 0 && "dependency parent count must stay positive");
|
assert(dependencyParents > 0 && "dependency parent count must stay positive");
|
||||||
--dependencyParents;
|
dependencyParents--;
|
||||||
if (dependencyParents == 0) {
|
if (dependencyParents == 0)
|
||||||
pushReady(childEdge.first);
|
readyNodes.push_back(childEdge.first);
|
||||||
++readyCount;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
DCP_DEBUG_IF(++gSelectTimers.tasksProcessed;
|
DCP_DEBUG_IF(
|
||||||
if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
|
++gSelectTimers.tasksProcessed;
|
||||||
gSelectTimers.dump("tick");)
|
if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
|
||||||
|
gSelectTimers.dump("tick");
|
||||||
|
)
|
||||||
}
|
}
|
||||||
progressLogger.printProgress(0, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), true);
|
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "done", true);
|
||||||
dumpDot();
|
dumpDot();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1582,11 +1259,8 @@ DCPAnalysisResult GraphDCP::getResult() {
|
|||||||
|
|
||||||
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
|
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
|
||||||
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
|
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||||
for (auto elem : dominanceOrder) {
|
for (auto elem : dominanceOrder)
|
||||||
auto spatCompute = elem->getSpatCompute();
|
ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
|
||||||
if (spatCompute)
|
|
||||||
ret.dominanceOrderCompute.push_back({spatCompute.getOperation(), 0});
|
|
||||||
}
|
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
|
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
|
||||||
const CpuTaskList* tasks = findCpuTasks(cpu);
|
const CpuTaskList* tasks = findCpuTasks(cpu);
|
||||||
@@ -1594,14 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
|
|||||||
continue;
|
continue;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
for (auto node : *tasks) {
|
for (auto node : *tasks) {
|
||||||
auto spatCompute = node->getSpatCompute();
|
ret.computeToCpuMap[node->getSpatCompute()] = cpu;
|
||||||
if (!spatCompute)
|
|
||||||
continue;
|
|
||||||
ComputeInstance instance {spatCompute.getOperation(), 0};
|
|
||||||
ret.computeToCpuMap[instance] = cpu;
|
|
||||||
if (i++ == tasks->size() - 1) {
|
if (i++ == tasks->size() - 1) {
|
||||||
ret.isLastComputeOfCpu.insert(instance);
|
ret.isLastComputeOfCpu.insert(node->getSpatCompute());
|
||||||
ret.cpuToLastComputeMap[cpu] = instance;
|
ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@@ -49,10 +48,8 @@ private:
|
|||||||
|
|
||||||
std::vector<TaskDCP> nodes;
|
std::vector<TaskDCP> nodes;
|
||||||
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
|
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
|
||||||
std::vector<uint64_t> taskStructureHashes;
|
|
||||||
std::vector<CpuTaskList> cpuTasks;
|
std::vector<CpuTaskList> cpuTasks;
|
||||||
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
|
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
|
||||||
llvm::DenseMap<CPU, uint64_t> cpuStructureHashes;
|
|
||||||
CPU lastCpu = 0;
|
CPU lastCpu = 0;
|
||||||
long long flag = 1;
|
long long flag = 1;
|
||||||
Time dcpl = 0;
|
Time dcpl = 0;
|
||||||
@@ -73,7 +70,6 @@ private:
|
|||||||
|
|
||||||
void initAest();
|
void initAest();
|
||||||
void initAlst();
|
void initAlst();
|
||||||
void initTaskStructureHashes();
|
|
||||||
|
|
||||||
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
|
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
|
||||||
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
|
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
|
||||||
@@ -87,15 +83,9 @@ private:
|
|||||||
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
|
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
|
||||||
// Returns true iff the full propagation completed without exceeding the
|
// Returns true iff the full propagation completed without exceeding the
|
||||||
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
|
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
|
||||||
bool tryUpdateAestWithinBudget(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder, Time dcplBudget);
|
bool tryUpdateAestWithinBudget(TaskDCP* task,
|
||||||
|
llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
|
||||||
// Incrementally refreshes ALST after `task` has been scheduled. Nodes
|
Time dcplBudget);
|
||||||
// outside the backward cone (`relations.ancestors` plus `task`) retain
|
|
||||||
// their relative distance to the sink boundary and only absorb the signed
|
|
||||||
// DCPL delta (`newDcpl - oldDcpl`). `task` itself and its ancestors are
|
|
||||||
// recomputed in reverse topological order so that new same-CPU transfer
|
|
||||||
// costs (now zero) and scheduling-edge children are reflected.
|
|
||||||
void updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl);
|
|
||||||
|
|
||||||
void initTopological();
|
void initTopological();
|
||||||
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
|
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
|
||||||
@@ -104,11 +94,8 @@ private:
|
|||||||
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
|
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
|
||||||
size_t getNodeIndex(const TaskDCP* task) const;
|
size_t getNodeIndex(const TaskDCP* task) const;
|
||||||
|
|
||||||
// Returns a compact dedup key for CPU `c` when evaluating `candidate`:
|
TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes);
|
||||||
// mixes candidateAest, crossbar usage, and the incremental cpu structure
|
void selectProcessor(TaskDCP* candidate, bool push);
|
||||||
// hash into a single uint64_t. Zero heap allocation.
|
|
||||||
uint64_t computeCpuCandidateKey(Time candidateAest, CPU cpu);
|
|
||||||
CandidateRelations selectProcessor(TaskDCP* candidate, bool push);
|
|
||||||
CPU getLastCpu() const { return lastCpu; }
|
CPU getLastCpu() const { return lastCpu; }
|
||||||
void incrementLastCpu() { lastCpu++; }
|
void incrementLastCpu() { lastCpu++; }
|
||||||
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
|
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
|
||||||
@@ -128,7 +115,8 @@ private:
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
void runDcp();
|
void runDcp();
|
||||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes, llvm::ArrayRef<IndexedEdge> edges)
|
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
|
||||||
|
llvm::ArrayRef<IndexedEdge> edges)
|
||||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||||
for (auto spatCompute : spatComputes)
|
for (auto spatCompute : spatComputes)
|
||||||
nodes.emplace_back(spatCompute);
|
nodes.emplace_back(spatCompute);
|
||||||
@@ -138,18 +126,13 @@ public:
|
|||||||
|
|
||||||
GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
|
GraphDCP(llvm::ArrayRef<Weight> nodeWeights,
|
||||||
llvm::ArrayRef<IndexedEdge> edges,
|
llvm::ArrayRef<IndexedEdge> edges,
|
||||||
llvm::ArrayRef<int64_t> nodeOrderKeys = {},
|
|
||||||
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
|
llvm::ArrayRef<CrossbarUsage> nodeCrossbarUsage = {})
|
||||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||||
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
|
assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size())
|
||||||
&& "synthetic crossbar usage must match synthetic node weights");
|
&& "synthetic crossbar usage must match synthetic node weights");
|
||||||
assert((nodeOrderKeys.empty() || nodeOrderKeys.size() == nodeWeights.size())
|
|
||||||
&& "synthetic node order keys must match synthetic node weights");
|
|
||||||
nodes.reserve(nodeWeights.size());
|
nodes.reserve(nodeWeights.size());
|
||||||
for (auto [index, weight] : llvm::enumerate(nodeWeights))
|
for (auto [index, weight] : llvm::enumerate(nodeWeights))
|
||||||
nodes.emplace_back(nodeOrderKeys.empty() ? static_cast<int64_t>(index) : nodeOrderKeys[index],
|
nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
|
||||||
weight,
|
|
||||||
nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]);
|
|
||||||
for (auto [start, end, weight] : edges)
|
for (auto [start, end, weight] : edges)
|
||||||
makeEdge(start, end, weight);
|
makeEdge(start, end, weight);
|
||||||
}
|
}
|
||||||
@@ -167,11 +150,6 @@ public:
|
|||||||
void setMaxCpuCount(int value) { maxCpuCount = value; }
|
void setMaxCpuCount(int value) { maxCpuCount = value; }
|
||||||
int getMaxCpuCount() const { return maxCpuCount; }
|
int getMaxCpuCount() const { return maxCpuCount; }
|
||||||
|
|
||||||
// Total crossbar units allocated across all active CPUs.
|
|
||||||
size_t crossbarsUsed() const;
|
|
||||||
// Maximum crossbar units available across all active CPUs (lastCpu * per-CPU capacity).
|
|
||||||
size_t crossbarsAvailable() const;
|
|
||||||
|
|
||||||
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
|
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
|
||||||
// null the scheduler runs single-threaded (tests use this path).
|
// null the scheduler runs single-threaded (tests use this path).
|
||||||
void setContext(mlir::MLIRContext* ctx) { context = ctx; }
|
void setContext(mlir::MLIRContext* ctx) { context = ctx; }
|
||||||
|
|||||||
@@ -35,12 +35,10 @@ void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSe
|
|||||||
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
|
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
|
||||||
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
|
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
|
||||||
|
|
||||||
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
|
void DcpProgressLogger::printStart(size_t readyCount) const {
|
||||||
if (!logProgress)
|
if (!logProgress)
|
||||||
return;
|
return;
|
||||||
llvm::errs() << llvm::formatv(
|
llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount);
|
||||||
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
|
|
||||||
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
||||||
@@ -50,15 +48,14 @@ void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
|
|||||||
if (!logProgress || elapsedSeconds < 1.0)
|
if (!logProgress || elapsedSeconds < 1.0)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
llvm::errs() << llvm::formatv("[DCP] slow node={0} elapsed={1} ready={2} cpus={3}\n",
|
llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n",
|
||||||
nodeIndex,
|
nodeIndex,
|
||||||
formatDuration(elapsedSeconds),
|
formatDuration(elapsedSeconds),
|
||||||
readyCount,
|
readyCount,
|
||||||
cpuCount);
|
cpuCount);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DcpProgressLogger::printProgress(
|
void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) {
|
||||||
size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force) {
|
|
||||||
if (!logProgress)
|
if (!logProgress)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@@ -71,19 +68,19 @@ void DcpProgressLogger::printProgress(
|
|||||||
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
|
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
|
||||||
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
|
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
|
||||||
|
|
||||||
bool done = completedTasks == totalTasks;
|
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n",
|
||||||
llvm::errs() << llvm::formatv(
|
completedTasks,
|
||||||
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
|
totalTasks,
|
||||||
completedTasks,
|
percent,
|
||||||
totalTasks,
|
readyCount,
|
||||||
percent,
|
cpuCount,
|
||||||
readyCount,
|
stage,
|
||||||
cpuCount,
|
formatDuration(elapsedSeconds),
|
||||||
maxCpuCount,
|
completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds));
|
||||||
xbarsUsed,
|
llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n",
|
||||||
xbarsAvailable,
|
formatDuration(findCandidateSeconds),
|
||||||
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
|
formatDuration(selectProcessorSeconds),
|
||||||
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
|
formatDuration(updateTimingSeconds));
|
||||||
lastProgressPrint = now;
|
lastProgressPrint = now;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -94,9 +91,9 @@ void DcpProgressLogger::recordFindDuration(double) {}
|
|||||||
void DcpProgressLogger::recordSelectDuration(double) {}
|
void DcpProgressLogger::recordSelectDuration(double) {}
|
||||||
void DcpProgressLogger::recordUpdateDuration(double) {}
|
void DcpProgressLogger::recordUpdateDuration(double) {}
|
||||||
void DcpProgressLogger::advanceCompleted(size_t) {}
|
void DcpProgressLogger::advanceCompleted(size_t) {}
|
||||||
void DcpProgressLogger::printStart(size_t, int, size_t) const {}
|
void DcpProgressLogger::printStart(size_t) const {}
|
||||||
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
|
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
|
||||||
void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
|
void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
@@ -31,10 +31,9 @@ public:
|
|||||||
void recordUpdateDuration(double seconds);
|
void recordUpdateDuration(double seconds);
|
||||||
void advanceCompleted(size_t taskCount = 1);
|
void advanceCompleted(size_t taskCount = 1);
|
||||||
|
|
||||||
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
|
void printStart(size_t readyCount) const;
|
||||||
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
|
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
|
||||||
void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
|
void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force);
|
||||||
size_t xbarsUsed, size_t xbarsAvailable, bool force);
|
|
||||||
|
|
||||||
#ifdef DCP_DEBUG_ENABLED
|
#ifdef DCP_DEBUG_ENABLED
|
||||||
private:
|
private:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -116,9 +116,10 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
|
|||||||
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
|
||||||
|
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPoint(coreOp);
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
||||||
|
|
||||||
rewriter.setInsertionPoint(mapOp);
|
rewriter.setInsertionPoint(mapOp);
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
|
|
||||||
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
|
||||||
pim::PimMemCopyOp::create(rewriter,
|
pim::PimMemCopyOp::create(rewriter,
|
||||||
mapOp.getLoc(),
|
mapOp.getLoc(),
|
||||||
@@ -257,18 +258,9 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Look through an optional pim.memcp_hd to find the source get_global.
|
|
||||||
// This occurs when the constant was staged into device memory before transposing.
|
|
||||||
pim::PimMemCopyHostToDevOp memcpHd;
|
|
||||||
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
auto sourceGetGlobal = transposeOp.getInput().getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!sourceGetGlobal) {
|
if (!sourceGetGlobal)
|
||||||
memcpHd = transposeOp.getInput().getDefiningOp<pim::PimMemCopyHostToDevOp>();
|
return failure();
|
||||||
if (!memcpHd)
|
|
||||||
return failure();
|
|
||||||
sourceGetGlobal = memcpHd.getHostSource().getDefiningOp<memref::GetGlobalOp>();
|
|
||||||
if (!sourceGetGlobal)
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
auto moduleOp = transposeOp->getParentOfType<ModuleOp>();
|
||||||
if (!moduleOp)
|
if (!moduleOp)
|
||||||
@@ -306,26 +298,13 @@ struct FoldConstantTransposePattern final : OpRewritePattern<pim::PimTransposeOp
|
|||||||
|
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
!transposeOp->getUsers().empty()
|
!transposeOp->getUsers().empty()
|
||||||
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) {
|
&& llvm::all_of(transposeOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
|
||||||
});
|
|
||||||
if (isAlwaysWeight) {
|
if (isAlwaysWeight) {
|
||||||
markWeightAlways(newGlobal);
|
markWeightAlways(newGlobal);
|
||||||
markWeightAlways(newGetGlobal);
|
markWeightAlways(newGetGlobal);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outputAllocOp = transposeOp.getOutputBuffer().getDefiningOp<memref::AllocOp>();
|
|
||||||
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
rewriter.replaceOp(transposeOp, newGetGlobal.getResult());
|
||||||
|
|
||||||
if (memcpHd && memcpHd.use_empty()) {
|
|
||||||
auto deviceAllocOp = memcpHd.getDeviceTarget().getDefiningOp<memref::AllocOp>();
|
|
||||||
rewriter.eraseOp(memcpHd);
|
|
||||||
if (deviceAllocOp && deviceAllocOp->use_empty())
|
|
||||||
rewriter.eraseOp(deviceAllocOp);
|
|
||||||
}
|
|
||||||
if (outputAllocOp && outputAllocOp->use_empty())
|
|
||||||
rewriter.eraseOp(outputAllocOp);
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -362,25 +341,18 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
if (!isa<pim::PimCoreOp>(user))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) {
|
||||||
return llvm::all_of(castOp->getUsers(), [](Operation* user) {
|
return llvm::all_of(castOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); });
|
||||||
return isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user);
|
|
||||||
});
|
|
||||||
})) {
|
})) {
|
||||||
allLiveUsersAreCoreOps = false;
|
allLiveUsersAreCoreOps = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
if (!llvm::all_of(allocOp->getUsers(), [](Operation* user) {
|
||||||
return isa<linalg::MapOp,
|
return isa<linalg::MapOp, memref::SubViewOp, memref::DeallocOp, memref::CastOp, pim::PimCoreOp>(user);
|
||||||
memref::SubViewOp,
|
|
||||||
memref::DeallocOp,
|
|
||||||
memref::CastOp,
|
|
||||||
pim::PimCoreOp,
|
|
||||||
pim::PimCoreBatchOp>(user);
|
|
||||||
})) {
|
})) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
@@ -417,83 +389,6 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FoldConstantHostCopyPattern final : OpRewritePattern<memref::CopyOp> {
|
|
||||||
using OpRewritePattern::OpRewritePattern;
|
|
||||||
|
|
||||||
LogicalResult matchAndRewrite(memref::CopyOp copyOp, PatternRewriter& rewriter) const override {
|
|
||||||
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto allocOp = copyOp.getTarget().getDefiningOp<memref::AllocOp>();
|
|
||||||
if (!allocOp)
|
|
||||||
return failure();
|
|
||||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
|
||||||
if (!allocType || !allocType.hasStaticShape())
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto srcSubview = getStaticSubviewInfo(copyOp.getSource());
|
|
||||||
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource());
|
|
||||||
|
|
||||||
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
|
||||||
if (!moduleOp)
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
|
||||||
if (failed(denseAttr))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
DenseElementsAttr foldedAttr;
|
|
||||||
if (succeeded(srcSubview)) {
|
|
||||||
if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
|
||||||
return failure();
|
|
||||||
auto staticOffsets = getStaticSubviewOffsets(*srcSubview);
|
|
||||||
if (failed(staticOffsets))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape());
|
|
||||||
if (failed(maybeFoldedAttr))
|
|
||||||
return failure();
|
|
||||||
foldedAttr = *maybeFoldedAttr;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
|
||||||
if (resultTensorType != denseAttr->getType())
|
|
||||||
return failure();
|
|
||||||
foldedAttr = *denseAttr;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool allLiveUsersAreCores = true;
|
|
||||||
for (Operation* user : allocOp->getUsers()) {
|
|
||||||
if (user == copyOp)
|
|
||||||
continue;
|
|
||||||
if (isa<memref::DeallocOp>(user))
|
|
||||||
continue;
|
|
||||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
|
||||||
continue;
|
|
||||||
if (isa<memref::SubViewOp>(user)) {
|
|
||||||
allLiveUsersAreCores = false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy");
|
|
||||||
if (allLiveUsersAreCores)
|
|
||||||
markWeightAlways(newGlobal);
|
|
||||||
|
|
||||||
rewriter.setInsertionPoint(allocOp);
|
|
||||||
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
|
||||||
if (allLiveUsersAreCores)
|
|
||||||
markWeightAlways(newGetGlobal);
|
|
||||||
|
|
||||||
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
|
||||||
rewriter.eraseOp(copyOp);
|
|
||||||
if (allocOp.use_empty())
|
|
||||||
rewriter.eraseOp(allocOp);
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -548,7 +443,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
continue;
|
continue;
|
||||||
if (isa<memref::DeallocOp>(user))
|
if (isa<memref::DeallocOp>(user))
|
||||||
continue;
|
continue;
|
||||||
if (isa<pim::PimCoreOp, pim::PimCoreBatchOp>(user))
|
if (isa<pim::PimCoreOp>(user))
|
||||||
continue;
|
continue;
|
||||||
if (isa<memref::SubViewOp>(user)) {
|
if (isa<memref::SubViewOp>(user)) {
|
||||||
allLiveUsersAreCores = false;
|
allLiveUsersAreCores = false;
|
||||||
@@ -578,11 +473,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
|||||||
|
|
||||||
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) {
|
||||||
patterns
|
patterns
|
||||||
.add<FoldConstantTransposePattern,
|
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, FoldConstantMemCpPattern>(
|
||||||
FoldConstantAllocPattern,
|
|
||||||
FoldConstantCoreMapPattern,
|
|
||||||
FoldConstantHostCopyPattern,
|
|
||||||
FoldConstantMemCpPattern>(
|
|
||||||
patterns.getContext());
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,26 +24,7 @@ static bool isAddressOnlyHostOp(Operation* op) {
|
|||||||
memref::CastOp,
|
memref::CastOp,
|
||||||
memref::CollapseShapeOp,
|
memref::CollapseShapeOp,
|
||||||
memref::ExpandShapeOp,
|
memref::ExpandShapeOp,
|
||||||
memref::CopyOp>(op);
|
spatial::SpatChannelNewOp>(op);
|
||||||
}
|
|
||||||
|
|
||||||
// Looser than isCodegenAddressableValue: follows view ops without requiring contiguity.
|
|
||||||
// Used for memref.copy operands which may be non-contiguous subviews.
|
|
||||||
static bool isBaseAddressableValue(Value value) {
|
|
||||||
while (true) {
|
|
||||||
if (isa<BlockArgument>(value))
|
|
||||||
return true;
|
|
||||||
Operation* defOp = value.getDefiningOp();
|
|
||||||
if (!defOp)
|
|
||||||
return false;
|
|
||||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(defOp))
|
|
||||||
return true;
|
|
||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) { value = subview.getSource(); continue; }
|
|
||||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) { value = cast.getSource(); continue; }
|
|
||||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) { value = collapse.getSrc(); continue; }
|
|
||||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) { value = expand.getSrc(); continue; }
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isCodegenAddressableValue(Value value) {
|
static bool isCodegenAddressableValue(Value value) {
|
||||||
@@ -57,8 +38,6 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||||
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
if (isa<pim::PimMemCopyHostToDevOp>(op))
|
||||||
return operandIndex == 1;
|
return operandIndex == 1;
|
||||||
if (isa<pim::PimMemCopyHostToDevBatchOp>(op))
|
|
||||||
return operandIndex == 1;
|
|
||||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||||
return operandIndex == 0;
|
return operandIndex == 0;
|
||||||
return false;
|
return false;
|
||||||
@@ -90,12 +69,6 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
|
||||||
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
|
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||||
if (failed(verifyReturnOp(returnOp)))
|
if (failed(verifyReturnOp(returnOp)))
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
@@ -119,11 +92,10 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename CoreOpTy>
|
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) {
|
||||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
|
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp) {
|
if (!getGlobalOp) {
|
||||||
coreOp.emitOpError() << "weight #" << weightIndex
|
coreOp.emitOpError() << "weight #" << weightIndex
|
||||||
<< " must be materialized as memref.get_global before JSON codegen";
|
<< " must be materialized as memref.get_global before JSON codegen";
|
||||||
@@ -159,8 +131,7 @@ private:
|
|||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename CoreOpTy>
|
static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) {
|
||||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
|
|
||||||
return walkPimCoreBlock(
|
return walkPimCoreBlock(
|
||||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
@@ -203,13 +174,6 @@ private:
|
|||||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
|
||||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
|
||||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
add_custom_target(pim-unittest)
|
add_custom_target(pim-unittest)
|
||||||
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
|||||||
+21
-24
@@ -457,10 +457,6 @@ int testDCPGraphDiamondDependencies() {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// crossbarSize=4, crossbarCount=2 => capacity = 4*4*2 = 32.
|
|
||||||
// Each task with crossbarUsage=1 needs footprint = 4*4 = 16, so at most 1 task
|
|
||||||
// can fit per CPU (16+16 = 32 >= capacity). The scheduler must open a fresh CPU
|
|
||||||
// for each task; all three end up on separate CPUs with their base weight.
|
|
||||||
int testDCPGraphCrossbarExhaustion() {
|
int testDCPGraphCrossbarExhaustion() {
|
||||||
std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl;
|
std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl;
|
||||||
configureDcpDotOutput();
|
configureDcpDotOutput();
|
||||||
@@ -477,36 +473,37 @@ int testDCPGraphCrossbarExhaustion() {
|
|||||||
|
|
||||||
const std::vector<Weight> nodeWeights = {10, 10, 10};
|
const std::vector<Weight> nodeWeights = {10, 10, 10};
|
||||||
const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1};
|
const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1};
|
||||||
GraphDCP graph(nodeWeights, {}, {},nodeCrossbarUsage);
|
GraphDCP graph(nodeWeights, {}, nodeCrossbarUsage);
|
||||||
graph.setMaxCpuCount(3);
|
graph.setMaxCpuCount(1);
|
||||||
graph.runDcp();
|
graph.runDcp();
|
||||||
|
|
||||||
if (graph.cpuCount() != 3) {
|
if (graph.cpuCount() != 1) {
|
||||||
restoreCrossbarOptions();
|
restoreCrossbarOptions();
|
||||||
std::cerr << "Expected 3 CPUs (one per task due to crossbar limit), got " << graph.cpuCount() << "\n";
|
std::cerr << "Expected exactly 1 CPU with maxCpuCount=1, got " << graph.cpuCount() << "\n";
|
||||||
dumpDcpFailureArtifacts();
|
dumpDcpFailureArtifacts();
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int failures = 0;
|
auto scheduledTasks = graph.getScheduledTasks(0);
|
||||||
for (CPU c = 0; c < 3; c++) {
|
if (scheduledTasks.size() != 3) {
|
||||||
auto scheduledTasks = graph.getScheduledTasks(c);
|
restoreCrossbarOptions();
|
||||||
if (scheduledTasks.size() != 1) {
|
std::cerr << "Expected all three tasks to be scheduled on CPU 0\n";
|
||||||
std::cerr << "Expected exactly 1 task on CPU " << c << ", got " << scheduledTasks.size() << "\n";
|
printCpuSchedule(graph, 0);
|
||||||
printCpuSchedule(graph, c);
|
dumpDcpFailureArtifacts();
|
||||||
failures++;
|
return 1;
|
||||||
continue;
|
}
|
||||||
}
|
|
||||||
if (scheduledTasks[0].weight != 10) {
|
if (scheduledTasks[0].weight != 10 || scheduledTasks[1].weight != std::numeric_limits<Weight>::max()
|
||||||
std::cerr << "Expected weight=10 on CPU " << c << ", got " << scheduledTasks[0].weight << "\n";
|
|| scheduledTasks[2].weight != std::numeric_limits<Weight>::max()) {
|
||||||
printCpuSchedule(graph, c);
|
restoreCrossbarOptions();
|
||||||
failures++;
|
std::cerr << "Unexpected effective weights under crossbar exhaustion\n";
|
||||||
}
|
printCpuSchedule(graph, 0);
|
||||||
|
dumpDcpFailureArtifacts();
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
restoreCrossbarOptions();
|
restoreCrossbarOptions();
|
||||||
if (failures) dumpDcpFailureArtifacts();
|
return 0;
|
||||||
return failures;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class ValidationResult:
|
|||||||
|
|
||||||
|
|
||||||
class ProgressReporter:
|
class ProgressReporter:
|
||||||
def __init__(self, total_models, stages_per_model=STAGE_COUNT, enabled=None):
|
def __init__(self, total_models, stages_per_model=STAGE_COUNT):
|
||||||
self.total_models = total_models
|
self.total_models = total_models
|
||||||
self.stages_per_model = stages_per_model
|
self.stages_per_model = stages_per_model
|
||||||
self.total_steps = max(1, total_models * stages_per_model)
|
self.total_steps = max(1, total_models * stages_per_model)
|
||||||
@@ -45,7 +45,7 @@ class ProgressReporter:
|
|||||||
self.passed_models = 0
|
self.passed_models = 0
|
||||||
self.failed_models = 0
|
self.failed_models = 0
|
||||||
self.current_label = ""
|
self.current_label = ""
|
||||||
self.enabled = sys.stdout.isatty() if enabled is None else enabled
|
self.enabled = True
|
||||||
self.columns = shutil.get_terminal_size((100, 20)).columns
|
self.columns = shutil.get_terminal_size((100, 20)).columns
|
||||||
self.suspended = False
|
self.suspended = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user