Parallel bufferization
README.md (154 lines changed)

@@ -1,5 +1,159 @@
# Raptor

Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).

## Overview

PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:

- a shared host memory,
- a number of cores that do most of the computation directly in their memory
  (vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
  support — any repeated work (e.g. convolutions) must be unrolled into
  explicit per-iteration instructions.

Because of this, the number of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.

A second target, `PulPim`, is planned for an accelerator with RISC-V cores,
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).

### Targets and simulators

`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.

## Compilation pipeline

The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).

High-level lowering flow:

```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```

1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
   Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
   Spatial models a high-level spatial in-memory accelerator: vmm/mvm
   operations are accelerated by storing a constant RHS matrix into a
   crossbar. Crossbars cannot be re-programmed during execution, have a
   limited fixed size, and there is a limited number of them per core.
   Conversion patterns are split by op family under
   `Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
   Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
   Reshape, Resize, Split).

2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
   Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
   materializes PIM cores (`pim.core`), inter-core communication
   (`pim.send` / `pim.receive`), halts, and crossbar-level operations.

3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
   A DCP-inspired heuristic (Dynamic Critical Path — see the original
   scheduling paper by Kwok & Ahmad,
   [DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
   that coarsens the virtual node graph and decides how to group compute
   nodes onto cores. Our implementation is only DCP-*inspired*: it is a
   heuristic with different assumptions from the paper (different cost
   model, constraints from crossbar capacity / core resources, and a
   windowed coarsening loop instead of full-graph reprioritization). The
   `dcp-critical-window-size` option controls how many lowest-slack virtual
   nodes each coarsening iteration considers (0 = legacy full-graph
   analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
   `MergeComputeNodesPass.cpp`. A sketch of the slack ordering it uses
   follows this list.

4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
   Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
   standard MLIR `BufferizableOpInterface` machinery
   (`OpBufferizationInterfaces.*`, `PimBufferization.td`). Each `pim.core`
   is bufferized independently; see the parallel-bufferization sketch at
   the end of this section.

5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
   - `HostConstantFolding` — folds host-side constants.
   - `MaterializeHostConstantsPass` — materializes the remaining host
     constants for emission.
   - `VerificationPass` — checks invariants before emission.
   - `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
     and `pim-simulator`.
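
The windowed selection in step 3 ranks virtual nodes by their scheduling slack
(the gap between the absolute latest and earliest start times computed by the
DCP analysis) and only lets the most critical nodes take part in a coarsening
iteration. The snippet below is a simplified, illustrative sketch of that
ordering, not the in-tree code; names such as `Timing` and `lowestSlackWindow`
are made up here, and the real `DCPGraph/DCPAnalysis.cpp` additionally grows
the window along graph edges and applies core/crossbar constraints.

```cpp
// Illustrative only: rank virtual nodes by slack and keep the lowest-slack
// window, mirroring the idea behind --dcp-critical-window-size.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

using Time = long long;

struct Timing {
  std::vector<Time> aest; // absolute earliest start time of each virtual node
  std::vector<Time> alst; // absolute latest start time of each virtual node
};

// Nodes on the dynamic critical path have zero slack.
static Time slackOf(const Timing& t, std::size_t node) {
  return std::max<Time>(0, t.alst[node] - t.aest[node]);
}

// Return the indices of the `windowSize` most critical virtual nodes.
std::vector<std::size_t> lowestSlackWindow(const Timing& t, std::size_t windowSize) {
  std::vector<std::size_t> order(t.aest.size());
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
    if (slackOf(t, a) != slackOf(t, b))
      return slackOf(t, a) < slackOf(t, b); // most critical first
    return t.aest[a] < t.aest[b];           // then earliest start time
  });
  order.resize(std::min(windowSize, order.size()));
  return order;
}
```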
|
Supporting pieces:
|
||||||
|
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
|
||||||
|
core count, DCP window, experimental conv impl, concat error handling, …)
|
||||||
|
and `PimCodeGen` entry points.
|
||||||
|
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
||||||
|
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
||||||
|
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
||||||
|
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
||||||
|
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
||||||
|
|
||||||
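
Step 4 above relies on `pim.core` being `IsolatedFromAbove`, so each core's
region can be analyzed without looking at the rest of the module; that is what
lets the bufferization pass run one-shot bufferization per core in parallel
before a final sequential run over the host-level IR. The following is a
minimal sketch of that driver, assuming the standard MLIR one-shot
bufferization entry points; it is not the in-tree `PimBufferizationPass`,
which wraps more handling (e.g. `restrict` annotations) around the same idea.

```cpp
// Minimal sketch: bufferize every pim.core in the entry function in parallel.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

static LogicalResult bufferizeCoresInParallel(func::FuncOp entryFunc) {
  // Cores are IsolatedFromAbove, so each one-shot analysis only sees its own
  // region and the worker threads never touch shared IR.
  auto coreOps = llvm::to_vector(entryFunc.getBody().getOps<pim::PimCoreOp>());
  return failableParallelForEach(entryFunc.getContext(), coreOps,
      [](pim::PimCoreOp coreOp) -> LogicalResult {
        // Options and state are allocated per core: nothing is shared.
        bufferization::OneShotBufferizationOptions options;
        options.allowUnknownOps = true;
        bufferization::BufferizationState state;
        if (failed(bufferization::runOneShotBufferize(coreOp, options, state)))
          return coreOp.emitError("failed to bufferize pim.core");
        return success();
      });
}
```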
|
## Key compiler options
|
||||||
|
|
||||||
|
Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||||
|
|
||||||
|
- `--maccel=PIM` — select the PIM accelerator.
|
||||||
|
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
|
||||||
|
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
|
||||||
|
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
|
||||||
|
run only the codegen tail.
|
||||||
|
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||||
|
per-core count.
|
||||||
|
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||||
|
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||||
|
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||||
|
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||||
|
|
||||||
|

## Validation

Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.

Per-operation validation (from `validation/`):

```
validate.py \
    --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
    --onnx-include-dir ../onnx-mlir/include
```

End-to-end network validation (example: first 4 layers of YOLOv11n):

```
validate.py \
    --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
    --onnx-include-dir ../onnx-mlir/include \
    --operations-dir ./networks/yolo11n/depth_04 \
    --crossbar-size 2048
```

Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.

## Rebuilding

Release build (fast):

```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```

A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).

## Build

### Protobuf

@@ -41,7 +41,7 @@ llvm::cl::opt<size_t>
     crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));

 llvm::cl::opt<size_t>
-    crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
+    crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));

 llvm::cl::opt<long> coresCount("core-count",
     llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
@@ -51,7 +51,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
     "dcp-critical-window-size",
     llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
                    "Use 0 to run the legacy full-graph DCP analysis."),
-    llvm::cl::init(1024));
+    llvm::cl::init(4000));

 llvm::cl::opt<bool>
     ignoreConcatError("ignore-concat-error",

@@ -1,3 +1,4 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -8,10 +9,10 @@
 #include "mlir/Transforms/Passes.h"

 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_os_ostream.h"

 #include <fstream>
@@ -24,8 +25,6 @@
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
 #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
 #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
-#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
-#include "src/Accelerators/PIM/Pass/PIMPasses.h"
 #include "src/Compiler/CompilerOptions.hpp"
 #include "src/Dialect/ONNX/ONNXOps.hpp"

@@ -52,7 +51,6 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
 private:
   void annotateWeightsConstants(func::FuncOp funcOp) const;
   void encapsulateGlobalInstruction(func::FuncOp funcOp);
-  void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
   LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
 };

@@ -156,8 +154,6 @@ void ONNXToSpatialPass::runOnOperation() {
     return;
   }

-  mergeTriviallyConnectedComputes(*entryFunc);
-
   // Dump to file for debug
   dumpModule(moduleOp, "spatial0");
 }
@@ -184,16 +180,6 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func

 bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
   if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
-    for (auto& use : toRemoveOp->getUses()) {
-      auto users = use.getOwner();
-      if (auto spatCompUser = dyn_cast<spatial::SpatCompute>(users)) {
-        unsigned int poistionUses = use.getOperandNumber();
-        if (poistionUses < spatCompUser.getInputs().getBeginOperandIndex())
-          return false;
-      }else {
-        return false;
-      }
-    }
     auto source = toRemoveOp.getSource();
     rewriter.setInsertionPointAfter(toRemoveOp);
     auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
@@ -215,27 +201,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
   if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
     auto sources = toRemoveOp.getInputs();
     rewriter.setInsertionPointAfter(toRemoveOp);
-    if (llvm::any_of(sources,
-            [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
-      auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
-      SmallVector<Type> sourceTypes;
-      SmallVector<Location> sourceLoc;
-      for (auto source : sources) {
-        sourceTypes.push_back(source.getType());
-        sourceLoc.push_back(loc);
-      }
-      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
-      newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
-      rewriter.setInsertionPointToEnd(BB);
-      IRMapping mapper;
-      for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
-        mapper.map(source, bbArg);
-      auto newConcat = rewriter.clone(*inst, mapper);
-      spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
-      inst->replaceAllUsesWith(newCompute->getResults());
-      inst->erase();
-      return true;
-    }
+    auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
+    SmallVector<Type> sourceTypes;
+    SmallVector<Location> sourceLoc;
+    for (auto source : sources) {
+      sourceTypes.push_back(source.getType());
+      sourceLoc.push_back(loc);
+    }
+    auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
+    newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
+    rewriter.setInsertionPointToEnd(BB);
+    IRMapping mapper;
+    for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
+      mapper.map(source, bbArg);
+    auto newConcat = rewriter.clone(*inst, mapper);
+    spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
+    inst->replaceAllUsesWith(newCompute->getResults());
+    inst->erase();
+    return true;
   }
   return false;
 }
@@ -297,6 +280,72 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
   return cast<Value>(mapped);
 }

+bool sourceOpernadHasWeightAlways(Operation* op) {
+  if (op == nullptr)
+    return false;
+
+  Operation* source = nullptr;
+  do {
+
+    if (isa<spatial::SpatCompute>(*op)) {
+      return false;
+    }
+    else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
+      auto tmpSource = extractSliceOp.getSource();
+      auto definingOp = tmpSource.getDefiningOp();
+      if (definingOp)
+        op = definingOp;
+      else
+        return false;
+    }
+    else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
+      auto tmpSource = expandShapeOp.getSrc();
+      auto definingOp = tmpSource.getDefiningOp();
+      if (definingOp)
+        op = definingOp;
+      else
+        return false;
+    }
+    else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
+      auto tmpSource = transposeOp.getData();
+      auto definingOp = tmpSource.getDefiningOp();
+      if (definingOp)
+        op = definingOp;
+      else
+        return false;
+    }
+    else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
+      auto tmpSource = collapseShapeOp.getSrc();
+      auto definingOp = tmpSource.getDefiningOp();
+      if (definingOp)
+        op = definingOp;
+      else
+        return false;
+    }
+    else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
+      source = constantOp;
+    }
+    else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
+      bool res = false;
+      for (auto operand : concatOp.getOperands()) {
+        res |= hasWeightAlways(operand.getDefiningOp());
+        if (res)
+          return res;
+      }
+      return res;
+    }
+    else {
+      op->dump();
+      llvm_unreachable("Global instruction not handle in func");
+    }
+  }
+  while (source == nullptr);
+
+  if (hasWeightAlways(source))
+    return true;
+  return false;
+}
+
 // TODO what we want to keep in global?
 void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
   Location loc = funcOp.getLoc();
@@ -305,6 +354,11 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
   while (keep) {
     keep = false;
     for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
+
+      if (isa<spatial::SpatCompute>(instruction) || isa<func::ReturnOp>(instruction)
+          || sourceOpernadHasWeightAlways(&instruction))
+        continue;
+
       keep |= encapsulateSlice(rewriter, loc, &instruction);

       keep |= encapsulator<tensor::ExpandShapeOp>(
@@ -321,99 +375,6 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
     }
   }
 }

-void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
-  Location loc = funcOp.getLoc();
-  IRRewriter rewriter(&getContext());
-  SmallVector<spatial::SpatCompute> trivialComputes;
-  llvm::SmallSet<spatial::SpatCompute, 8> toErase;
-
-  for (auto compute : funcOp.getOps<spatial::SpatCompute>())
-    if (compute->hasOneUse()) {
-      auto& use = *compute->getUses().begin();
-      auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
-
-      if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
-        trivialComputes.push_back(compute);
-    }
-
-  while (!trivialComputes.empty()) {
-    auto compute = trivialComputes.front();
-
-    if (compute.use_empty()) {
-      std::swap(trivialComputes.front(), trivialComputes.back());
-      trivialComputes.pop_back();
-      continue;
-    }
-    auto& computeUse = *compute->getUses().begin();
-    auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
-    auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
-    auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
-
-    rewriter.setInsertionPointAfter(compute.getOperation());
-
-    auto newCompute = spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
-    newCompute.getProperties().setOperandSegmentSizes(
-        {static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
-
-    IRMapping mapper;
-    auto weightMutableIter = newCompute.getWeightsMutable();
-    for (auto weight : child.getWeights()) {
-      auto founded = llvm::find(newCompute.getWeights(), weight);
-      if (founded == newCompute.getWeights().end()) {
-        weightMutableIter.append(weight);
-        auto last = weightMutableIter.end();
-        last = std::prev(last, 1);
-        mapper.map(weight, last->get());
-      }
-      else {
-        mapper.map(weight, *founded);
-      }
-    }
-
-    compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
-    auto newTerminator = newCompute.getBody().front().getTerminator();
-    mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
-    newTerminator->erase();
-    rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
-    for (auto& op : child.getBody().front()) {
-      auto newInst = rewriter.clone(op, mapper);
-
-      if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
-        auto oldIndex = vmOp.getWeightIndex();
-        auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
-        auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
-        vmOp.setWeightIndex(newIndex);
-      }
-      if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
-        auto oldIndex = vmOp.getWeightIndex();
-        auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
-        auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
-        vmOp.setWeightIndex(newIndex);
-      }
-    }
-
-    child.replaceAllUsesWith(newCompute);
-    toErase.insert(child);
-
-    std::swap(trivialComputes.front(), trivialComputes.back());
-    trivialComputes.pop_back();
-    toErase.insert(compute);
-
-    if (newCompute->hasOneUse()) {
-      auto& use = *newCompute->getUses().begin();
-      auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
-      if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
-        trivialComputes.push_back(newCompute);
-    }
-  }
-
-  for (auto compute : toErase) {
-    for (Value result : compute->getResults())
-      result.dropAllUses();
-    compute.erase();
-  }
-}
-
 void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
   funcOp.walk([&](arith::ConstantOp constantOp) {
     if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))

@@ -96,8 +96,8 @@ bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
   return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
 }

-mlir::Value createPimReceiveFromSpatialChannel(
-    PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
+mlir::Value
+createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
   mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
   auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
   auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
@@ -127,6 +127,16 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
   return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
 }

+bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
+  for (Operation* user : value.getUsers()) {
+    if (user->getBlock() != operation->getBlock())
+      return true;
+    if (operation->isBeforeInBlock(user))
+      return true;
+  }
+  return false;
+}
+
 mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
   assert("Only support operations with a single result" && operation->getNumResults() == 1);
   mlir::Value result = operation->getResult(0);
@@ -134,8 +144,9 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
   assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));

   SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
-  auto validOperands =
-      make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
+  auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
+    return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
+  });
   auto bestOperand = validOperands.begin();

   if (bestOperand != validOperands.end())

@@ -2,10 +2,12 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
@@ -21,6 +23,73 @@ using namespace mlir;
 namespace onnx_mlir {
 namespace {

+struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
+    Location loc = extractSliceOp.getLoc();
+
+    if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
+      return failure();
+
+    for (auto& uses : extractSliceOp->getUses()) {
+      if (isa<spatial::SpatCompute>(uses.getOwner())) {
+        auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
+        if (spatCompute.getInputs().empty())
+          return failure();
+        if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
+          return failure();
+      }
+      else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
+        return failure();
+      }
+    }
+
+    llvm::DenseMap<spatial::SpatCompute, Value> mapSpatToExtract;
+
+    for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
+
+      if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
+        auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
+        auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
+
+        if (BBArgValue.use_empty())
+          continue;
+
+        rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
+        if (!mapSpatToExtract.contains(spatCompute)) {
+          auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
+          mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
+        }
+
+        rewriter.startOpModification(spatCompute.getOperation());
+        BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute]);
+        spatCompute.getInputsMutable().erase(BBArgIndex);
+        spatCompute.getBody().front().eraseArgument(BBArgIndex);
+        rewriter.finalizeOpModification(spatCompute.getOperation());
+      }
+      else {
+        {
+          auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>();
+
+          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
+          if (!mapSpatToExtract.contains(spatCompute)) {
+            auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
+            mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
+          }
+
+          rewriter.startOpModification(spatCompute.getOperation());
+          uses.set(mapSpatToExtract[spatCompute]);
+          rewriter.finalizeOpModification(spatCompute.getOperation());
+        }
+      }
+    }
+
+    rewriter.eraseOp(extractSliceOp);
+    return success();
+  }
+};
+
 struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
   using OpRewritePattern::OpRewritePattern;

@@ -34,6 +103,15 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
     if (!isa<func::FuncOp>(constantOp->getParentOp()))
       return failure();

+    if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
+          if (isa<spatial::SpatCompute>(op))
+            return false;
+          if (isa<func::FuncOp>(op->getParentOp()))
+            return true;
+          return false;
+        }))
+      return failure();
+
     rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());

     auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
@@ -51,7 +129,9 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
         rewriter.getUnitAttr(),
         {});

-    for (auto& constUses : constantOp->getUses()) {
+    llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
+
+    for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
       auto constUsers = constUses.getOwner();

       if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
@@ -59,23 +139,43 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
         auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
         auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
         rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
-        auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
-        auto toTensor = bufferization::ToTensorOp::create(
-            rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
+        if (!mapSpatComputeToConst.contains(spatCompute)) {
+          auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
+          auto toTensor = bufferization::ToTensorOp::create(
+              rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
+          mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
+        }
+
         rewriter.startOpModification(spatCompute.getOperation());
-        BBArgValue.replaceAllUsesWith(toTensor);
+        BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute]);
         spatCompute.getInputsMutable().erase(BBArgIndex);
         spatCompute.getBody().front().eraseArgument(BBArgIndex);
         rewriter.finalizeOpModification(spatCompute.getOperation());
       }
       else {
-        llvm_unreachable("Who are using const globally");
+        {
+          auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>();
+          if (!spatCompute)
+            continue;
+
+          rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
+          if (!mapSpatComputeToConst.contains(spatCompute)) {
+            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
+            auto toTensor = bufferization::ToTensorOp::create(
+                rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
+            mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
+          }
+
+          rewriter.startOpModification(spatCompute.getOperation());
+          constUses.set(mapSpatComputeToConst[spatCompute]);
+          rewriter.finalizeOpModification(spatCompute.getOperation());
+        }
       }
     }
   }
   else if (constantOp.getType().isIntOrIndexOrFloat()) {
     llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;

     for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
       auto constUsers = constUses.getOwner();
@@ -180,7 +280,8 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO

 } // namespace
 void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
-  patterns.add<FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(patterns.getContext());
+  patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
+      patterns.getContext());
 }

 } // namespace onnx_mlir

@@ -10,6 +10,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -57,7 +58,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
   void runOnOperation() final;

 private:
-  SmallVector<Value> outputTensors;
+  SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
   size_t coreId = 0;
   SmallVector<Operation*> operationsToRemove;

@@ -293,7 +294,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
   auto storedType = cast<ShapedType>(storedValue.getType());
   size_t elementSize = storedType.getElementTypeBitWidth() / 8;

-  Value outputTensor = outputTensors[resultIndexInReturn];
+  auto outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
   if (auto storedOp = storedValue.getDefiningOp())
     rewriter.setInsertionPointAfter(storedOp);
   PimMemCopyDevToHostOp::create(rewriter,
@@ -315,8 +316,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
   size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;

   // Store to global memory
-  Value outputTensor = outputTensors[resultIndexInReturn];
   rewriter.setInsertionPointAfterValue(yieldValue);
+  Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
   PimMemCopyDevToHostOp::create(rewriter,
       loc,
       outputTensor.getType(),
@@ -356,8 +357,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
   size_t elementSize = yieldType.getElementTypeBitWidth() / 8;

   // Store to global memory
-  Value outputTensor = outputTensors[concatIndexInReturn];
   rewriter.setInsertionPointAfterValue(yieldValue);
+  Value outputTensor = outputTensors[concatIndexInReturn](rewriter, loc);
   PimMemCopyDevToHostOp::create(rewriter,
       loc,
       outputTensor.getType(),
@@ -463,17 +464,35 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I

 void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
   outputTensors.reserve(returnOp->getNumOperands());
+  for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
     rewriter.setInsertionPointToStart(returnOp->getBlock());
-  for (auto returnValue : returnOp->getOperands()) {
     Operation* returnValueDefiningOp = returnValue.getDefiningOp();
     if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
       assert(!hasWeightAlways(returnValueDefiningOp));
-      outputTensors.push_back(returnValue);
+      outputTensors.push_back( [returnValue] (IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
     }
     else {
-      auto newOutputTensor =
-          createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
-      outputTensors.push_back(newOutputTensor);
+      auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
+      mlir::MemRefType memRefType =
+          mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
+
+      std::string outputName = "output_" + std::to_string(index);
+      rewriter.setInsertionPoint(returnOp.getParentOp());
+      memref::GlobalOp::create(rewriter,
+          returnOp.getLoc(),
+          rewriter.getStringAttr(outputName),
+          rewriter.getStringAttr("private"),
+          TypeAttr::get(memRefType),
+          {},
+          {},
+          {});
+      outputTensors.push_back(
+          [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
+            auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
+            auto toTensor = bufferization::ToTensorOp::create(
+                rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
+            return toTensor.getResult();
+          });
     }
   }
 }
@@ -502,7 +521,6 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
     rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
   };

-
   for (auto& op : funcOp.getBody().getOps())
     if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
       assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle");
@@ -694,12 +712,13 @@ void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter&

 void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
   SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
+  auto loc = returnOp.getLoc();
   for (auto it : llvm::enumerate(originalOperands)) {
     size_t orderWithinReturn = it.index();
     Operation* returnOperand = it.value().getDefiningOp();
+    rewriter.setInsertionPoint(returnOp);
     rewriter.modifyOpInPlace(returnOp,
-        [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
+        [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn](rewriter, loc)); });

     Operation* opToErase = returnOperand;
     while (opToErase) {

@@ -24,7 +24,7 @@ def PimTensor :
 // Execution
 //===----------------------------------------------------------------------===//

-def PimCoreOp : PimOp<"core", [SingleBlock]> {
+def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
   let summary = "Execute a block on a PIM core";

   let regions = (region SizedRegion<1>:$body);

@@ -3,12 +3,17 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Threading.h"
 #include "mlir/Pass/Pass.h"

+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+
 #include "Common/PimCommon.hpp"
 #include "Compiler/PimCodeGen.hpp"
 #include "Dialect/Pim/PimOps.hpp"
 #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
 #include "src/Accelerators/PIM/Pass/PIMPasses.h"
 #include "src/Compiler/CompilerOptions.hpp"

@@ -40,24 +45,44 @@ private:

 void PimBufferizationPass::runOnOperation() {
   auto moduleOp = getOperation();
+  // Refactor this into a function
+  {
+    auto funcOp = getPimEntryFunc(moduleOp);

-  // One-Shot-Bufferization
-  bufferization::OneShotBufferizationOptions options;
-  options.allowUnknownOps = true;
-  bufferization::BufferizationState state;
+    auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
+    MLIRContext* ctx = moduleOp.getContext();
+    // failableParallelForEach will run the lambda in parallel and stop if any thread fails
+    LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
+      // Again, allocate state LOCALLY per thread/function
+      bufferization::OneShotBufferizationOptions options;
+      options.allowUnknownOps = true;
+      bufferization::BufferizationState state;
+      if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
+        coreOp.emitError("Failed to bufferize PIM and Spatial ops");
+        return failure();
+      }
+      return success();
+    });

-  /*for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {*/
-  /*  for (auto pimCoreOp : funcOp.getOps<PimCoreOp>()) {*/
-  /*    if (failed(bufferization::runOneShotBufferize(pimCoreOp, options, state))) {*/
-  /*      moduleOp.emitError("Failed to bufferize PIM and Spatial ops");*/
-  /*      signalPassFailure();*/
-  /*    }*/
-  /*  }*/
-  /*}*/
+    if (failed(result)) {
+      moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
+      signalPassFailure();
+    }

-  if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
-    moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
-    signalPassFailure();
+    funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
+      if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
+        toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
+    });
+
+    // One-Shot-Bufferization
+    bufferization::OneShotBufferizationOptions options;
+    options.allowUnknownOps = true;
+    bufferization::BufferizationState state;
+
+    if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
+      moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
+      signalPassFailure();
+    }
   }

   MLIRContext* ctx = moduleOp.getContext();
|||||||
@@ -3,15 +3,17 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/IR/ValueRange.h"
|
#include "mlir/IR/ValueRange.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <map>
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <set>
|
#include <queue>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -47,11 +49,13 @@ struct TimingInfo {
|
|||||||
|
|
||||||
struct WindowScheduleResult {
|
struct WindowScheduleResult {
|
||||||
std::vector<std::vector<size_t>> mergeGroups;
|
std::vector<std::vector<size_t>> mergeGroups;
|
||||||
bool usedAllAvailableCpus = false;
|
CPU cpuCount = 0;
|
||||||
|
size_t mergedNodeCount = 0;
|
||||||
|
size_t maxMergeGroupSize = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
||||||
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
for (auto [start, end, weight] : edges) {
|
for (auto [start, end, weight] : edges) {
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
@@ -59,11 +63,9 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|||||||
continue;
|
continue;
|
||||||
auto key = std::make_pair(startIndex, endIndex);
|
auto key = std::make_pair(startIndex, endIndex);
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
auto it = edgeWeights.find(key);
|
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||||
if (it == edgeWeights.end())
|
if (!inserted.second)
|
||||||
edgeWeights.insert({key, edgeWeight});
|
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||||
else
|
|
||||||
it->second = std::max(it->second, edgeWeight);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregatedEdges;
|
std::vector<IndexedEdge> aggregatedEdges;
|
||||||
@@ -71,6 +73,11 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|||||||
for (auto [key, weight] : edgeWeights)
|
for (auto [key, weight] : edgeWeights)
|
||||||
aggregatedEdges.push_back(
|
aggregatedEdges.push_back(
|
||||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||||
|
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
||||||
|
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||||
|
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||||
|
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||||
|
});
|
||||||
return aggregatedEdges;
|
return aggregatedEdges;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,10 +164,27 @@ TimingInfo computeTiming(const VirtualGraph& graph) {
   return timing;
 }

-std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
-  std::vector<size_t> selected(timing.aest.size());
-  std::iota(selected.begin(), selected.end(), 0);
-  std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
+std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
+  std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
+  for (auto [start, end, weight] : graph.edges) {
+    (void) weight;
+    size_t startIndex = static_cast<size_t>(start);
+    size_t endIndex = static_cast<size_t>(end);
+    assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
+    adjacency[startIndex].push_back(endIndex);
+    adjacency[endIndex].push_back(startIndex);
+  }
+  for (auto& neighbours : adjacency) {
+    llvm::sort(neighbours);
+    neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
+  }
+  return adjacency;
+}
+
+std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
+  std::vector<size_t> ranked(timing.aest.size());
+  std::iota(ranked.begin(), ranked.end(), 0);
+  auto isHigherPriority = [&](size_t lhs, size_t rhs) {
     Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
     Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
     if (lhsSlack != rhsSlack)
@@ -168,19 +192,83 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
     if (timing.aest[lhs] != timing.aest[rhs])
       return timing.aest[lhs] < timing.aest[rhs];
     return lhs < rhs;
-  });
-  selected.resize(std::min(windowSize, selected.size()));
-  return selected;
-}
+  };

-std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
-  std::vector<size_t> signature;
-  for (size_t nodeIndex : selectedNodes) {
-    const VirtualNode& node = graph.nodes[nodeIndex];
-    signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
+  windowSize = std::min(windowSize, ranked.size());
+  if (windowSize == 0)
+    return {};
+  if (windowSize == ranked.size()) {
+    llvm::sort(ranked, isHigherPriority);
+    return ranked;
   }
-  std::sort(signature.begin(), signature.end());
-  return signature;
+  size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
+  if (criticalPoolSize < ranked.size())
+    std::nth_element(
+        ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
+
+  std::vector<char> inCriticalPool(ranked.size(), false);
+  for (size_t i = 0; i < criticalPoolSize; ++i)
+    inCriticalPool[ranked[i]] = true;
+
+  size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
+  std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
+  std::vector<size_t> selected;
+  std::vector<char> inWindow(ranked.size(), false);
+  selected.reserve(windowSize);
+
+  struct FrontierEntry {
+    size_t node;
+  };
+  auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
+  std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
+
+  auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
+    if (inWindow[node])
+      return;
+    inWindow[node] = true;
+    selected.push_back(node);
+    for (size_t neighbour : adjacency[node])
+      if (!inWindow[neighbour] && eligible[neighbour])
+        frontier.push({neighbour});
+  };
+
+  addToWindow(seed, inCriticalPool);
+  while (!frontier.empty() && selected.size() < windowSize) {
+    size_t node = frontier.top().node;
+    frontier.pop();
+    if (!inWindow[node])
+      addToWindow(node, inCriticalPool);
+  }
+
+  if (selected.size() < windowSize) {
+    std::vector<char> anyNode(ranked.size(), true);
+    for (size_t node : selected)
+      for (size_t neighbour : adjacency[node])
+        if (!inWindow[neighbour])
+          frontier.push({neighbour});
+    while (!frontier.empty() && selected.size() < windowSize) {
+      size_t node = frontier.top().node;
+      frontier.pop();
+      if (!inWindow[node])
+        addToWindow(node, anyNode);
+    }
+  }
+
+  if (selected.size() < windowSize) {
+    llvm::sort(ranked, isHigherPriority);
+    for (size_t node : ranked) {
+      if (selected.size() == windowSize)
+        break;
+      if (!inWindow[node]) {
+        inWindow[node] = true;
+        selected.push_back(node);
+      }
+    }
+  }
+
+  llvm::sort(selected, isHigherPriority);
+  return selected;
 }

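selectCriticalWindow now grows a connected window instead of just taking the top-k slack nodes: it seeds from the most critical node and expands along the undirected adjacency with a priority frontier, preferring low-slack neighbours from the critical pool before opening up further. The toy program below shows only that growth step under simplifying assumptions (slack values handed in directly, no critical pool or fill-up passes); it is an illustration, not the VirtualGraph implementation.

```cpp
// Toy sketch: grow a connected window from the most critical seed, always
// expanding the lowest-slack frontier node first. Types are simplified.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <queue>
#include <vector>

std::vector<size_t> growWindow(const std::vector<std::vector<size_t>>& adjacency,
                               const std::vector<unsigned>& slack, size_t windowSize) {
  // Comparator makes the priority_queue behave as a min-heap on slack.
  auto lessCritical = [&](size_t a, size_t b) { return slack[a] > slack[b]; };
  std::priority_queue<size_t, std::vector<size_t>, decltype(lessCritical)> frontier(lessCritical);

  size_t seed = std::min_element(slack.begin(), slack.end()) - slack.begin();
  std::vector<char> inWindow(adjacency.size(), false);
  std::vector<size_t> selected;
  frontier.push(seed);
  while (!frontier.empty() && selected.size() < windowSize) {
    size_t node = frontier.top();
    frontier.pop();
    if (inWindow[node])
      continue; // duplicate frontier entry
    inWindow[node] = true;
    selected.push_back(node);
    for (size_t neighbour : adjacency[node])
      if (!inWindow[neighbour])
        frontier.push(neighbour);
  }
  return selected; // connected (when the graph is), biased towards low slack
}

int main() {
  // Path 0-1-2-3 plus a high-slack side node 4 attached to 1.
  std::vector<std::vector<size_t>> adj = {{1}, {0, 2, 4}, {1, 3}, {2}, {1}};
  std::vector<unsigned> slack = {0, 1, 0, 2, 9};
  for (size_t node : growWindow(adj, slack, 3))
    std::cout << node << ' '; // prints "0 1 2"
  std::cout << '\n';
}
```

Growing the window around the seed keeps the nodes handed to the window scheduler structurally related, which is what makes the merge groups it returns meaningful.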
 std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
@@ -216,25 +304,47 @@ WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t>
   windowGraph.runDcp();

   WindowScheduleResult result;
-  result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();
+  result.cpuCount = windowGraph.cpuCount();
   for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
     auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
     if (scheduledTasks.size() < 2)
       continue;

+    result.mergedNodeCount += scheduledTasks.size();
+    result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
     std::vector<size_t> mergeGroup;
     mergeGroup.reserve(scheduledTasks.size());
     for (const auto& task : scheduledTasks)
       mergeGroup.push_back(selectedNodes[task.nodeIndex]);
-    std::sort(mergeGroup.begin(), mergeGroup.end());
     result.mergeGroups.push_back(std::move(mergeGroup));
   }
   return result;
 }

-bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
+bool coarsenGraph(const VirtualGraph& graph,
+    ArrayRef<std::vector<size_t>> mergeGroups,
+    VirtualGraph& coarsenedGraph,
+    std::vector<size_t>& oldToNewNode) {
+  TimingInfo timing = computeTiming(graph);
+  std::vector<size_t> topologicalRank(graph.nodes.size());
+  std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
+  if (timing.valid)
+    for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
+      topologicalRank[nodeIndex] = rank;
+
+  std::vector<std::vector<size_t>> orderedMergeGroups;
+  orderedMergeGroups.reserve(mergeGroups.size());
+  for (const auto& mergeGroup : mergeGroups) {
+    orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
+    std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
+      if (topologicalRank[lhs] != topologicalRank[rhs])
+        return topologicalRank[lhs] < topologicalRank[rhs];
+      return lhs < rhs;
+    });
+  }
+
   std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
-  for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
+  for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
     if (mergeGroup.size() < 2)
       continue;
     for (size_t nodeIndex : mergeGroup) {
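coarsenGraph now orders every merge group by a topological rank taken from computeTiming before it builds the merged node, so member payloads are appended in dependency order and the later edge remap can drop back-edges by comparing ranks. The standalone sketch below shows one way such ranks can be computed (Kahn's algorithm) and used to order a group; the helper names and graph encoding are hypothetical, not the VirtualGraph API.

```cpp
// Hypothetical standalone sketch: topological ranks via Kahn's algorithm, then
// a stable sort of a merge group by rank with the node index as tie-breaker.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

std::vector<size_t> topologicalRanks(size_t nodeCount,
                                     const std::vector<std::pair<size_t, size_t>>& edges) {
  std::vector<std::vector<size_t>> successors(nodeCount);
  std::vector<size_t> indegree(nodeCount, 0);
  for (auto [from, to] : edges) {
    successors[from].push_back(to);
    ++indegree[to];
  }
  std::queue<size_t> ready;
  for (size_t node = 0; node < nodeCount; ++node)
    if (indegree[node] == 0)
      ready.push(node);
  std::vector<size_t> rank(nodeCount, 0);
  size_t next = 0;
  while (!ready.empty()) {
    size_t node = ready.front();
    ready.pop();
    rank[node] = next++;
    for (size_t succ : successors[node])
      if (--indegree[succ] == 0)
        ready.push(succ);
  }
  return rank; // meaningful only if the graph is acyclic (next == nodeCount)
}

int main() {
  auto rank = topologicalRanks(5, {{0, 2}, {2, 1}, {1, 4}, {0, 3}});
  std::vector<size_t> group = {4, 1, 2};
  std::stable_sort(group.begin(), group.end(), [&](size_t lhs, size_t rhs) {
    if (rank[lhs] != rank[rhs])
      return rank[lhs] < rank[rhs];
    return lhs < rhs;
  });
  for (size_t node : group)
    std::cout << node << ' '; // prints the group in dependency order: 2 1 4
  std::cout << '\n';
}
```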
@@ -243,18 +353,21 @@ bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> merge
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
|
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||||
std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
|
std::vector<size_t> newNodeRank;
|
||||||
|
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||||
bool mergedAny = false;
|
bool mergedAny = false;
|
||||||
coarsenedGraph.nodes.clear();
|
coarsenedGraph.nodes.clear();
|
||||||
coarsenedGraph.edges.clear();
|
coarsenedGraph.edges.clear();
|
||||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
newNodeRank.reserve(graph.nodes.size());
|
||||||
|
|
||||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||||
if (mergeGroupIndex == -1) {
|
if (mergeGroupIndex == -1) {
|
||||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||||
|
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,7 +378,7 @@ bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> merge
|
|||||||
}
|
}
|
||||||
|
|
||||||
VirtualNode mergedNode;
|
VirtualNode mergedNode;
|
||||||
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||||
memberNode.originalComputeIndices.end());
|
memberNode.originalComputeIndices.end());
|
||||||
@@ -276,8 +389,9 @@ bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> merge
|
|||||||
|
|
||||||
mergedAny = true;
|
mergedAny = true;
|
||||||
newNodeIndex = coarsenedGraph.nodes.size();
|
newNodeIndex = coarsenedGraph.nodes.size();
|
||||||
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||||
|
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,75 +405,61 @@ bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> merge
|
|||||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||||
if (newStart == newEnd)
|
if (newStart == newEnd)
|
||||||
continue;
|
continue;
|
||||||
|
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||||
|
continue;
|
||||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||||
}
|
}
|
||||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||||
|
|
||||||
return computeTiming(coarsenedGraph).valid;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
constexpr CPU kDefaultMaxCpuCount = 1000;
|
||||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
|
||||||
VirtualGraph& coarsenedGraph) {
|
|
||||||
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
|
||||||
return true;
|
|
||||||
|
|
||||||
std::vector<size_t> orderedGroupIndices(mergeGroups.size());
|
CPU getVirtualGraphMaxCpuCount() {
|
||||||
std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
|
if (coresCount.getValue() > 0)
|
||||||
std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
|
return static_cast<CPU>(coresCount.getValue());
|
||||||
return mergeGroups[lhs].size() > mergeGroups[rhs].size();
|
return kDefaultMaxCpuCount;
|
||||||
});
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> acceptedMergeGroups;
|
|
||||||
acceptedMergeGroups.reserve(mergeGroups.size());
|
|
||||||
for (size_t groupIndex : orderedGroupIndices) {
|
|
||||||
std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
|
|
||||||
candidateMergeGroups.push_back(mergeGroups[groupIndex]);
|
|
||||||
|
|
||||||
VirtualGraph candidateGraph;
|
|
||||||
if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
acceptedMergeGroups = std::move(candidateMergeGroups);
|
|
||||||
coarsenedGraph = std::move(candidateGraph);
|
|
||||||
}
|
|
||||||
return !acceptedMergeGroups.empty();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
|
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
||||||
VirtualGraph graph;
|
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
||||||
graph.nodes.resize(computeCount);
|
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
||||||
graph.edges = aggregateEdges(edges);
|
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||||
TimingInfo timing = computeTiming(graph);
|
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||||
if (timing.valid)
|
return windowSize;
|
||||||
return timing.topologicalOrder;
|
|
||||||
|
|
||||||
std::vector<size_t> fallbackOrder(computeCount);
|
|
||||||
std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
|
|
||||||
return fallbackOrder;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<SpatCompute> spatComputes) {
|
||||||
ArrayRef<SpatCompute> spatComputes,
|
|
||||||
ArrayRef<IndexedEdge> originalEdges) {
|
|
||||||
DCPAnalysisResult result;
|
DCPAnalysisResult result;
|
||||||
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
|
||||||
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
|
||||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
|
||||||
originalToVirtualNode[originalIndex] = virtualNodeIndex;
|
|
||||||
|
|
||||||
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
|
TimingInfo timing = computeTiming(graph);
|
||||||
result.dominanceOrderCompute.reserve(dominanceOrder.size());
|
std::vector<size_t> virtualNodeOrder;
|
||||||
for (size_t originalIndex : dominanceOrder) {
|
if (timing.valid) {
|
||||||
SpatCompute spatCompute = spatComputes[originalIndex];
|
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||||
size_t cpu = originalToVirtualNode[originalIndex];
|
}
|
||||||
|
else {
|
||||||
|
virtualNodeOrder.resize(graph.nodes.size());
|
||||||
|
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> originalComputeToCpu(spatComputes.size(), 0);
|
||||||
|
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||||
|
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
||||||
|
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||||
|
originalComputeToCpu[originalIndex] = cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.dominanceOrderCompute.reserve(spatComputes.size());
|
||||||
|
for (auto [originalIndex, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||||
|
size_t cpu = originalComputeToCpu[originalIndex];
|
||||||
result.dominanceOrderCompute.push_back(spatCompute);
|
result.dominanceOrderCompute.push_back(spatCompute);
|
||||||
result.computeToCpuMap[spatCompute] = cpu;
|
result.computeToCpuMap[spatCompute] = cpu;
|
||||||
result.cpuToLastComputeMap[cpu] = spatCompute;
|
result.cpuToLastComputeMap[cpu] = spatCompute;
|
||||||
}
|
}
|
||||||
|
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
|
|
||||||
result.isLastComputeOfCpu.insert(lastCompute);
|
result.isLastComputeOfCpu.insert(lastCompute);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -409,32 +509,74 @@ DCPAnalysisResult DCPAnalysis::run() {
|
|||||||
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
|
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
|
||||||
|
|
||||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
|
||||||
std::set<std::vector<size_t>> seenCriticalWindows;
|
size_t iteration = 0;
|
||||||
while (virtualGraph.nodes.size() > 1) {
|
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||||
TimingInfo timing = computeTiming(virtualGraph);
|
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||||
if (!timing.valid)
|
|
||||||
break;
|
|
||||||
|
|
||||||
auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
|
|
||||||
if (selectedNodes.size() < 2)
|
|
||||||
break;
|
|
||||||
|
|
||||||
if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
|
|
||||||
break;
|
|
||||||
|
|
||||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||||
if (windowSchedule.mergeGroups.empty())
|
if (windowSchedule.mergeGroups.empty()) {
|
||||||
break;
|
if (oldNodeCount >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
VirtualGraph coarsenedGraph;
|
VirtualGraph coarsenedGraph;
|
||||||
if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
|
std::vector<size_t> oldToNewNode;
|
||||||
break;
|
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||||
|
return false;
|
||||||
|
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount,
|
||||||
|
windowSchedule.mergeGroups.size(),
|
||||||
|
windowSchedule.mergedNodeCount,
|
||||||
|
windowSchedule.maxMergeGroupSize,
|
||||||
|
coarsenedGraph.nodes.size(),
|
||||||
|
oldNodeCount - coarsenedGraph.nodes.size());
|
||||||
virtualGraph = std::move(coarsenedGraph);
|
virtualGraph = std::move(coarsenedGraph);
|
||||||
if (windowSchedule.usedAllAvailableCpus)
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
|
iteration++;
|
||||||
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
|
if (!timing.valid) {
|
||||||
|
if (virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||||
break;
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<size_t> selectedNodes;
|
||||||
|
auto criticalWindow =
|
||||||
|
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
||||||
|
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||||
|
|
||||||
|
if (selectedNodes.size() < 2) {
|
||||||
|
if (virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||||
|
iteration,
|
||||||
|
virtualGraph.nodes.size(),
|
||||||
|
selectedNodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||||
|
continue;
|
||||||
|
if (virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
|
return buildResultFromVirtualGraph(virtualGraph, spatComputes);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -38,11 +38,14 @@

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"

 #include <algorithm>
 #include <cassert>
 #include <chrono>
+#include <cstdint>
 #include <cstdio>
+#include <queue>
 #include <vector>

 #include "DCPAnalysis.hpp"
@@ -60,6 +63,7 @@ namespace {
 // Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set.
 struct SelectTimers {
   double findSlot = 0.0;
+  double dedup = 0.0;
   double precheck = 0.0;
   double snapshotInsertUpdate = 0.0;
   double childSlot = 0.0;
@@ -70,9 +74,19 @@ struct SelectTimers {
   long tasksProcessed = 0;
   void dump(const char* label) const {
     std::fprintf(stderr,
-        "[selectProfile:%s] tasks=%ld findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
-        label, tasksProcessed, findSlot, precheck, snapshotInsertUpdate, childSlot,
-        rollbackRestore, iterations, passedPrecheck, passedDcpl);
+        "[selectProfile:%s] tasks=%ld dedup=%.2fs findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs "
+        "childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
+        label,
+        tasksProcessed,
+        dedup,
+        findSlot,
+        precheck,
+        snapshotInsertUpdate,
+        childSlot,
+        rollbackRestore,
+        iterations,
+        passedPrecheck,
+        passedDcpl);
   }
   ~SelectTimers() {
     if (std::getenv("DCP_SELECT_PROFILE"))
@@ -83,6 +97,101 @@ static SelectTimers gSelectTimers;
 } // namespace
 #endif

+namespace {
+
+uint64_t mixHash(uint64_t seed, uint64_t value) {
+  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+  return seed;
+}
+
+uint64_t finishHash(uint64_t seed) {
+  seed ^= seed >> 33;
+  seed *= 0xff51afd7ed558ccdULL;
+  seed ^= seed >> 33;
+  seed *= 0xc4ceb9fe1a85ec53ULL;
+  seed ^= seed >> 33;
+  return seed;
+}
+
+uint64_t hashEdgeSignature(uint64_t neighborHash, Weight weight, uint64_t direction) {
+  uint64_t hash = mixHash(0x84222325cbf29ce4ULL, direction);
+  hash = mixHash(hash, neighborHash);
+  hash = mixHash(hash, static_cast<uint64_t>(weight));
+  return finishHash(hash);
+}
+
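The new hashing helpers follow a familiar recipe: a hash_combine-style mix folds values into a seed, and a MurmurHash3-style finalizer scrambles the bits before the value is used as a structural signature. The snippet below lifts the two helpers into a standalone program with an illustrative driver; the hashSequence wrapper and its seed constant are assumptions made for the demo, not part of the patch.

```cpp
// Standalone illustration of the mix/finalize hashing; driver is illustrative only.
#include <cstdint>
#include <cstdio>
#include <initializer_list>

uint64_t mixHash(uint64_t seed, uint64_t value) {
  // boost::hash_combine-style mixing step.
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
  return seed;
}

uint64_t finishHash(uint64_t seed) {
  // MurmurHash3 fmix64 finalizer: spreads entropy across all 64 bits.
  seed ^= seed >> 33;
  seed *= 0xff51afd7ed558ccdULL;
  seed ^= seed >> 33;
  seed *= 0xc4ceb9fe1a85ec53ULL;
  seed ^= seed >> 33;
  return seed;
}

// Order-sensitive hash of a value sequence; equal sequences hash equally.
uint64_t hashSequence(std::initializer_list<uint64_t> values) {
  uint64_t hash = 0x7442b1129fd01363ULL; // arbitrary non-zero seed for the demo
  for (uint64_t value : values)
    hash = mixHash(hash, value);
  return finishHash(hash);
}

int main() {
  std::printf("%016llx\n", (unsigned long long) hashSequence({1, 2, 3}));
  std::printf("%016llx\n", (unsigned long long) hashSequence({3, 2, 1})); // differs: order matters
}
```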
+struct CpuAestCache {
+  Time defaultAest = 0;
+  llvm::SmallDenseMap<CPU, Time, 8> colocatedParentAests;
+
+  Time get(CPU cpu) const {
+    auto it = colocatedParentAests.find(cpu);
+    if (it == colocatedParentAests.end())
+      return defaultAest;
+    return it->second;
+  }
+};
+
+struct CpuTimeMax {
+  CPU cpu = -1;
+  Time time = 0;
+};
+
+void updateCpuTimeMax(CpuTimeMax& first, CpuTimeMax& second, CPU cpu, Time time) {
+  if (first.cpu == cpu) {
+    first.time = std::max(first.time, time);
+    return;
+  }
+  if (second.cpu == cpu) {
+    second.time = std::max(second.time, time);
+    if (second.time > first.time)
+      std::swap(first, second);
+    return;
+  }
+  if (time >= first.time) {
+    second = first;
+    first = {cpu, time};
+    return;
+  }
+  if (time > second.time)
+    second = {cpu, time};
+}
+
+CpuAestCache computeCpuAestCache(TaskDCP* task) {
+  CpuAestCache cache;
+  llvm::SmallDenseMap<CPU, Time, 8> transferAestByCpu;
+  llvm::SmallDenseMap<CPU, Time, 8> localAestByCpu;
+  Time unscheduledTransferAest = 0;
+
+  for (const Edge& parentEdge : task->parents) {
+    Time parentFinish = addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight());
+    Time transferAest = addOrMax(parentFinish, getTransferCost(parentEdge.first, task));
+    if (std::optional<CPU> parentCpu = parentEdge.first->getCpu()) {
+      Time& cpuTransferAest = transferAestByCpu[*parentCpu];
+      cpuTransferAest = std::max(cpuTransferAest, transferAest);
+      Time& cpuLocalAest = localAestByCpu[*parentCpu];
+      cpuLocalAest = std::max(cpuLocalAest, parentFinish);
+      continue;
+    }
+    unscheduledTransferAest = std::max(unscheduledTransferAest, transferAest);
+  }
+
+  CpuTimeMax firstOther {-1, unscheduledTransferAest};
+  CpuTimeMax secondOther {-1, 0};
+  for (const auto& entry : transferAestByCpu)
+    updateCpuTimeMax(firstOther, secondOther, entry.first, entry.second);
+
+  cache.defaultAest = firstOther.time;
+  for (const auto& entry : localAestByCpu) {
+    CPU cpu = entry.first;
+    Time bestNonLocalParentAest = firstOther.cpu == cpu ? secondOther.time : firstOther.time;
+    cache.colocatedParentAests[cpu] = std::max(bestNonLocalParentAest, entry.second);
+  }
+  return cache;
+}
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Edge manipulation
 //===----------------------------------------------------------------------===//
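computeCpuAestCache leans on a small trick visible in updateCpuTimeMax: by tracking only the best and second-best (cpu, time) entries in one pass, it can later answer "best transfer AEST over all parents not on this CPU" in constant time per CPU. Below is a minimal standalone sketch of that top-two-maxima idea with simplified int/uint64_t types; it mirrors the branch structure of the patch but is not the scheduler code.

```cpp
// Sketch: keep the top two (key, value) maxima so that "best value over all
// keys except k" can be answered in O(1) for any k. Types simplified.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct TopTwo {
  int bestKey = -1;
  uint64_t bestValue = 0;
  int secondKey = -1;
  uint64_t secondValue = 0;

  void update(int key, uint64_t value) {
    if (key == bestKey) {
      bestValue = std::max(bestValue, value);
    } else if (key == secondKey) {
      secondValue = std::max(secondValue, value);
      if (secondValue > bestValue) { // second may overtake the leader
        std::swap(bestKey, secondKey);
        std::swap(bestValue, secondValue);
      }
    } else if (value >= bestValue) {
      secondKey = bestKey;
      secondValue = bestValue;
      bestKey = key;
      bestValue = value;
    } else if (value > secondValue) {
      secondKey = key;
      secondValue = value;
    }
  }

  // Largest value contributed by any key other than `excluded`.
  uint64_t maxExcluding(int excluded) const { return bestKey == excluded ? secondValue : bestValue; }
};

int main() {
  TopTwo top;
  std::vector<std::pair<int, uint64_t>> samples = {{0, 10}, {1, 7}, {0, 12}, {2, 9}};
  for (auto [key, value] : samples)
    top.update(key, value);
  std::cout << top.maxExcluding(0) << '\n'; // 9: best among keys 1 and 2
  std::cout << top.maxExcluding(1) << '\n'; // 12: key 0 still counts
}
```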
@@ -156,6 +265,49 @@ std::vector<TaskDCP*> GraphDCP::getRoots() {
|
|||||||
return tmp;
|
return tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void GraphDCP::initTaskStructureHashes() {
|
||||||
|
taskStructureHashes.resize(nodes.size());
|
||||||
|
for (auto [index, task] : llvm::enumerate(nodes)) {
|
||||||
|
uint64_t hash = mixHash(0x7442b1129fd01363ULL, static_cast<uint64_t>(task.getWeight()));
|
||||||
|
hash = mixHash(hash, static_cast<uint64_t>(task.getCrossbarUsage()));
|
||||||
|
taskStructureHashes[index] = finishHash(hash);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<uint64_t> nextHashes(nodes.size());
|
||||||
|
std::vector<uint64_t> edgeHashes;
|
||||||
|
for (int iteration = 0; iteration < 4; ++iteration) {
|
||||||
|
for (auto [index, task] : llvm::enumerate(nodes)) {
|
||||||
|
uint64_t hash = mixHash(0x464dcab27ac82291ULL, taskStructureHashes[index]);
|
||||||
|
edgeHashes.clear();
|
||||||
|
edgeHashes.reserve(task.parents.size() + task.children.size());
|
||||||
|
for (const Edge& parent : task.parents)
|
||||||
|
if (!parent.isScheduling)
|
||||||
|
edgeHashes.push_back(
|
||||||
|
hashEdgeSignature(taskStructureHashes[getNodeIndex(parent.first)], parent.second, /*direction=*/0));
|
||||||
|
for (const Edge& child : task.children)
|
||||||
|
if (!child.isScheduling)
|
||||||
|
edgeHashes.push_back(
|
||||||
|
hashEdgeSignature(taskStructureHashes[getNodeIndex(child.first)], child.second, /*direction=*/1));
|
||||||
|
llvm::sort(edgeHashes);
|
||||||
|
hash = mixHash(hash, static_cast<uint64_t>(edgeHashes.size()));
|
||||||
|
for (uint64_t edgeHash : edgeHashes)
|
||||||
|
hash = mixHash(hash, edgeHash);
|
||||||
|
nextHashes[index] = finishHash(hash);
|
||||||
|
}
|
||||||
|
taskStructureHashes.swap(nextHashes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact dedup key for CPU `c` vs `candidate`: mixes candidateAest, crossbar
|
||||||
|
// usage, and the incremental cpu structure hash. No heap allocation.
|
||||||
|
uint64_t GraphDCP::computeCpuCandidateKey(Time candidateAest, CPU cpu) {
|
||||||
|
uint64_t hash = mixHash(0xd6e8feb86659fd93ULL, static_cast<uint64_t>(candidateAest));
|
||||||
|
hash = mixHash(hash, static_cast<uint64_t>(getCpuCrossbarUsage(cpu)));
|
||||||
|
auto it = cpuStructureHashes.find(cpu);
|
||||||
|
hash = mixHash(hash, it != cpuStructureHashes.end() ? it->second : 0ULL);
|
||||||
|
return finishHash(hash);
|
||||||
|
}
|
||||||
|
|
||||||
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
|
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
|
||||||
// neighbouring tasks and keeping the global topological order consistent.
|
// neighbouring tasks and keeping the global topological order consistent.
|
||||||
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
|
||||||
@@ -164,6 +316,7 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
|
|||||||
task->setCpu(cpu);
|
task->setCpu(cpu);
|
||||||
task->setWeight(scheduledWeight);
|
task->setWeight(scheduledWeight);
|
||||||
reserveTaskCrossbars(cpu, task);
|
reserveTaskCrossbars(cpu, task);
|
||||||
|
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
|
||||||
auto& tasksInCpu = getOrCreateCpuTasks(cpu);
|
auto& tasksInCpu = getOrCreateCpuTasks(cpu);
|
||||||
unsigned int numCpuTasks = tasksInCpu.size();
|
unsigned int numCpuTasks = tasksInCpu.size();
|
||||||
assert(position <= numCpuTasks && "Inserting in a not valid position");
|
assert(position <= numCpuTasks && "Inserting in a not valid position");
|
||||||
@@ -201,6 +354,7 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
|
|||||||
|
|
||||||
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
|
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
|
||||||
releaseTaskCrossbars(cpu, task);
|
releaseTaskCrossbars(cpu, task);
|
||||||
|
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
|
||||||
task->resetCpu();
|
task->resetCpu();
|
||||||
task->resetWeight();
|
task->resetWeight();
|
||||||
auto& scheduledTasks = getOrCreateCpuTasks(cpu);
|
auto& scheduledTasks = getOrCreateCpuTasks(cpu);
|
||||||
@@ -271,6 +425,21 @@ bool GraphDCP::wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const
|
|||||||
return nextUsage >= getCpuCrossbarCapacity();
|
return nextUsage >= getCpuCrossbarCapacity();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t GraphDCP::crossbarsUsed() const {
|
||||||
|
CrossbarUsage crossbarEdge = static_cast<CrossbarUsage>(onnx_mlir::crossbarSize.getValue());
|
||||||
|
CrossbarUsage crossbarArea = crossbarEdge * crossbarEdge;
|
||||||
|
if (crossbarArea == 0)
|
||||||
|
return 0;
|
||||||
|
CrossbarUsage totalArea = 0;
|
||||||
|
for (const auto& [cpu, usage] : cpuCrossbarUsage)
|
||||||
|
totalArea = checkedAdd(totalArea, usage);
|
||||||
|
return static_cast<size_t>(totalArea / crossbarArea);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t GraphDCP::crossbarsAvailable() const {
|
||||||
|
return static_cast<size_t>(lastCpu) * onnx_mlir::crossbarCountInCore.getValue();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AEST / ALST computation
|
// AEST / ALST computation
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@@ -456,9 +625,9 @@ void GraphDCP::updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<T
|
|||||||
for (TaskDCP* descendant : descendantsTopoOrder)
|
for (TaskDCP* descendant : descendantsTopoOrder)
|
||||||
recomputeAest(descendant);
|
recomputeAest(descendant);
|
||||||
|
|
||||||
const bool oldMaxInvalidated = maxCompletionTask != nullptr
|
const bool oldMaxInvalidated =
|
||||||
&& (maxCompletionTask == task
|
maxCompletionTask != nullptr
|
||||||
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
||||||
if (oldMaxInvalidated) {
|
if (oldMaxInvalidated) {
|
||||||
// The pre-update max came from a modified task; its completion has moved
|
// The pre-update max came from a modified task; its completion has moved
|
||||||
// upward, so modifiedMaxCompletion is an upper bound covering it. The
|
// upward, so modifiedMaxCompletion is an upper bound covering it. The
|
||||||
@@ -523,9 +692,9 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
|
|||||||
if (!process(descendant))
|
if (!process(descendant))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
const bool oldMaxInvalidated = maxCompletionTask != nullptr
|
const bool oldMaxInvalidated =
|
||||||
&& (maxCompletionTask == task
|
maxCompletionTask != nullptr
|
||||||
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
|
||||||
if (oldMaxInvalidated) {
|
if (oldMaxInvalidated) {
|
||||||
dcpl = modifiedMaxCompletion;
|
dcpl = modifiedMaxCompletion;
|
||||||
maxCompletion = modifiedMaxCompletion;
|
maxCompletion = modifiedMaxCompletion;
|
||||||
@@ -546,6 +715,109 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
   return true;
 }

+// Incrementally refreshes ALST after `task` was placed. The set of nodes whose
+// ALST is structurally affected by the insertion is exactly
+// `relations.ancestors ∪ {task}`: the task's outgoing transfer costs to
+// same-CPU real children become 0, and new scheduling edges create parent
+// relationships between `task` and its same-CPU neighbors. Every other node
+// keeps its relative distance to the sink boundary and only absorbs the
+// signed DCPL delta captured between `oldDcpl` and the now-updated `dcpl`.
+void GraphDCP::updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl) {
+  Time newDcpl = getDcpl();
+  // If the AEST update saturated dcpl (e.g. rescue placement on a
+  // crossbar-exhausted CPU sets task weight to UINT64_MAX), the shift delta
+  // would be meaningless. Fall back to a full recompute for this step only.
+  if (newDcpl == std::numeric_limits<Time>::max()) {
+    initAlst();
+    return;
+  }
+
+  if (newDcpl != oldDcpl) {
+    const bool increased = newDcpl > oldDcpl;
+    const Time delta = increased ? (newDcpl - oldDcpl) : (oldDcpl - newDcpl);
+    for (TaskDCP& node : topologicalOrder) {
+      if (&node == task || relations.ancestors.contains(&node))
+        continue;
+      Time alst = node.getAlst();
+      node.setAlst(increased ? addOrMax(alst, delta) : subtractOrZero(alst, delta));
+    }
+  }
+
+  auto recomputeAlst = [&](TaskDCP* node) {
+    Time minAlst = std::numeric_limits<Time>::max();
+    if (!node->hasChildren())
+      minAlst = subtractOrZero(newDcpl, node->getWeight());
+    for (const Edge& childEdge : node->children)
+      minAlst = std::min(minAlst,
+          subtractOrZero(childEdge.first->getAlst(),
+              addOrMax(node->getWeight(), getTransferCost(node, childEdge.first))));
+    node->setAlst(minAlst);
+  };
+
+  // Walk the backward cone with a pending-children counter so that every
+  // ancestor is recomputed only after all of its affected children have
+  // been refreshed. This is resilient to staleness in the global
+  // `topologicalOrder` relative to freshly added scheduling edges.
+  llvm::DenseSet<TaskDCP*> affected = relations.ancestors;
+  affected.insert(task);
+  llvm::DenseMap<TaskDCP*, int> pendingAffectedChildren;
+  pendingAffectedChildren.reserve(affected.size());
+  std::vector<TaskDCP*> worklist;
+  worklist.reserve(affected.size());
+  for (TaskDCP* node : affected) {
+    int count = 0;
+    for (const Edge& childEdge : node->children)
+      if (affected.contains(childEdge.first))
+        count++;
+    pendingAffectedChildren[node] = count;
+    if (count == 0)
+      worklist.push_back(node);
+  }
+  while (!worklist.empty()) {
+    TaskDCP* node = worklist.back();
+    worklist.pop_back();
+    recomputeAlst(node);
+    for (const Edge& parentEdge : node->parents) {
+      if (!affected.contains(parentEdge.first))
+        continue;
+      auto it = pendingAffectedChildren.find(parentEdge.first);
+      assert(it != pendingAffectedChildren.end());
+      if (--it->second == 0)
+        worklist.push_back(parentEdge.first);
+    }
+  }
+
+  // Opt-in consistency check: verifies the incremental ALST result against a
+  // full initAlst() recomputation. Very expensive (O(V+E) per placement) - only
+  // enable when investigating suspected drift.
+#ifdef DCP_DEBUG_CHECK_ALST
+  std::vector<Time> afterIncremental(nodes.size());
+  for (size_t i = 0; i < nodes.size(); ++i)
+    afterIncremental[i] = nodes[i].getAlst();
+  initAlst();
+  bool mismatched = false;
+  for (size_t i = 0; i < nodes.size(); ++i) {
+    if (afterIncremental[i] != nodes[i].getAlst()) {
+      if (!mismatched) {
+        llvm::errs() << "[alst-mismatch] placed=" << getNodeIndex(task) << " oldDcpl=" << oldDcpl
+                     << " newDcpl=" << newDcpl << " ancestors={";
+        for (TaskDCP* a : relations.ancestors)
+          llvm::errs() << getNodeIndex(a) << ",";
+        llvm::errs() << "}\n";
+        mismatched = true;
+      }
+      llvm::errs() << " node=" << i << " incremental=" << afterIncremental[i] << " full=" << nodes[i].getAlst()
+                   << " weight=" << nodes[i].getWeight()
+                   << " cpu=" << (nodes[i].isScheduled() ? (int) *nodes[i].getCpu() : -1) << " children=[";
+      for (const Edge& e : nodes[i].children)
+        llvm::errs() << getNodeIndex(e.first) << (e.isScheduling ? "s" : "")
+                     << "(tc=" << getTransferCost(&nodes[i], e.first) << ",alst=" << e.first->getAlst() << "),";
+      llvm::errs() << "]\n";
+    }
+  }
+#endif
+}
+
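The pending-children walk above is effectively Kahn's algorithm restricted to the affected subset: a node is recomputed only once all of its affected children are done, and unaffected children simply contribute their existing values. The self-contained toy below applies the same pattern to a plain "longest distance to a sink" refresh on a four-node DAG; the graph, the metric, and the affected set are invented for the illustration.

```cpp
// Sketch of the pending-children worklist: visit every node in `affected` only
// after all of its affected successors were visited; unaffected successors keep
// their old values. The real pass recomputes ALST instead of distances.
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // DAG: 0 -> 1 -> 3, 0 -> 2 -> 3 ; node 3 is the sink.
  std::vector<std::vector<size_t>> children = {{1, 2}, {3}, {3}, {}};
  std::vector<std::vector<size_t>> parents = {{}, {0}, {0}, {1, 2}};
  std::vector<int> distToSink = {2, 1, 1, 0};

  std::set<size_t> affected = {0, 1, 3}; // pretend node 3 and its ancestors via 1 changed
  std::vector<int> pending(children.size(), 0);
  std::vector<size_t> worklist;
  for (size_t node : affected) {
    for (size_t child : children[node])
      if (affected.count(child))
        ++pending[node];
    if (pending[node] == 0)
      worklist.push_back(node); // no affected children left to wait for
  }
  while (!worklist.empty()) {
    size_t node = worklist.back();
    worklist.pop_back();
    int best = 0; // sinks keep distance 0
    for (size_t child : children[node])
      best = std::max(best, distToSink[child] + 1); // unaffected children use old values
    distToSink[node] = best;
    for (size_t parent : parents[node])
      if (affected.count(parent) && --pending[parent] == 0)
        worklist.push_back(parent);
  }
  for (int d : distToSink)
    std::cout << d << ' '; // prints "2 1 1 0"
  std::cout << '\n';
}
```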
 // Computes a localised ALST: only ancestors of the candidate (plus the
 // candidate itself) get recomputed, every other task keeps its current ALST.
 // Processes nodes in reverse dependency order using a pending-children
@@ -905,32 +1177,6 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
|
|||||||
// Candidate selection and processor assignment
|
// Candidate selection and processor assignment
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Lowest slack wins; earliest AEST breaks ties. Critical-path tasks (zero
|
|
||||||
// slack) naturally float to the front.
|
|
||||||
TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
|
|
||||||
auto findBestNode = [](auto lft, auto rgt) {
|
|
||||||
Time leftSlack = slackOrZero((*lft)->getAest(), (*lft)->getAlst());
|
|
||||||
Time rightSlack = slackOrZero((*rgt)->getAest(), (*rgt)->getAlst());
|
|
||||||
if (leftSlack < rightSlack)
|
|
||||||
return lft;
|
|
||||||
if (rightSlack < leftSlack)
|
|
||||||
return rgt;
|
|
||||||
if ((*lft)->getAest() < (*rgt)->getAest())
|
|
||||||
return lft;
|
|
||||||
return rgt;
|
|
||||||
};
|
|
||||||
|
|
||||||
assert(!readyNodes.empty() && "expected at least one ready node");
|
|
||||||
auto validNode = readyNodes.begin();
|
|
||||||
auto bestNode = validNode;
|
|
||||||
|
|
||||||
while (validNode != readyNodes.end()) {
|
|
||||||
bestNode = findBestNode(validNode, bestNode);
|
|
||||||
std::advance(validNode, 1);
|
|
||||||
}
|
|
||||||
return *bestNode;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Picks the best CPU + slot for `candidate`:
|
// Picks the best CPU + slot for `candidate`:
|
||||||
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
|
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
|
||||||
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
|
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
|
||||||
@@ -939,7 +1185,7 @@ TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
|
|||||||
// evaluate a slot for the smallest-slack child, then roll back.
|
// evaluate a slot for the smallest-slack child, then roll back.
|
||||||
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
|
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
|
||||||
// otherwise pick the CPU that leads to the smallest DCPL increase.
|
// otherwise pick the CPU that leads to the smallest DCPL increase.
|
||||||
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
||||||
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
|
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
|
||||||
relations.descendantsTopoOrder.reserve(relations.descendants.size());
|
relations.descendantsTopoOrder.reserve(relations.descendants.size());
|
||||||
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
|
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
|
||||||
@@ -959,22 +1205,43 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
|
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
|
||||||
const bool candidateHasCrossbar = candidateFootprint != 0;
|
const bool candidateHasCrossbar = candidateFootprint != 0;
|
||||||
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
|
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
|
||||||
|
DCP_DEBUG_IF(auto dedupStart = std::chrono::steady_clock::now();)
|
||||||
|
CpuAestCache cpuAests = computeCpuAestCache(candidate);
|
||||||
|
DCP_DEBUG_IF(const bool checkCpuAestCache = std::getenv("DCP_CHECK_CPU_AEST_CACHE") != nullptr;)
|
||||||
|
llvm::SmallDenseSet<uint64_t, 32> seenProcessorKeys;
|
||||||
|
seenProcessorKeys.reserve(static_cast<size_t>(topCpu + 1));
|
||||||
for (CPU c = 0; c <= topCpu; c++) {
|
for (CPU c = 0; c <= topCpu; c++) {
|
||||||
if (candidateHasCrossbar && c != getLastCpu()) {
|
if (candidateHasCrossbar && c != getLastCpu()) {
|
||||||
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
|
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
|
||||||
if (nextUsage >= cpuCapacity)
|
if (nextUsage >= cpuCapacity)
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Time candidateAest = cpuAests.get(c);
|
||||||
|
DCP_DEBUG_IF(if (checkCpuAestCache) {
|
||||||
|
Time recomputedAest = computeAestOnCpu(candidate, c);
|
||||||
|
if (candidateAest != recomputedAest) {
|
||||||
|
std::fprintf(stderr,
|
||||||
|
"[DCP_CHECK_CPU_AEST_CACHE] mismatch candidate=%zu cpu=%d cached=%llu recomputed=%llu\n",
|
||||||
|
getNodeIndex(candidate),
|
||||||
|
c,
|
||||||
|
static_cast<unsigned long long>(candidateAest),
|
||||||
|
static_cast<unsigned long long>(recomputedAest));
|
||||||
|
llvm::report_fatal_error("DCP CPU AEST cache mismatch");
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if (!seenProcessorKeys.insert(computeCpuCandidateKey(candidateAest, c)).second)
|
||||||
|
continue;
|
||||||
processors.push_back(c);
|
processors.push_back(c);
|
||||||
}
|
}
|
||||||
|
DCP_DEBUG_IF(gSelectTimers.dedup +=
|
||||||
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - dedupStart).count();)
|
||||||
if (processors.empty()) {
|
if (processors.empty()) {
|
||||||
CPU bestCpu = canCreateNewCpu ? getLastCpu() : 0;
|
// processors.empty() implies !canCreateNewCpu: a fresh CPU always passes
|
||||||
FindSlot bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
|
// the crossbar filter and would have been added. Reaching here means every
|
||||||
if (canCreateNewCpu)
|
// existing CPU is crossbar-exhausted and the task requires crossbar
|
||||||
incrementLastCpu();
|
// capacity — the placement is impossible.
|
||||||
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
|
llvm::report_fatal_error("DCP scheduler: crossbar capacity exhausted on all CPUs; "
|
||||||
return;
|
"cannot schedule task that requires crossbar allocation");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Phase 1: parallel findSlot sweep (read-only over graph state).
|
// Phase 1: parallel findSlot sweep (read-only over graph state).
|
||||||
@@ -1000,21 +1267,20 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
for (size_t i = 0; i < processors.size(); ++i)
|
for (size_t i = 0; i < processors.size(); ++i)
|
||||||
sweep(i);
|
sweep(i);
|
||||||
DCP_DEBUG_IF(gSelectTimers.findSlot +=
|
DCP_DEBUG_IF(gSelectTimers.findSlot +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
|
||||||
|
|
||||||
#ifdef DCP_DEBUG_ENABLED
|
#ifdef DCP_DEBUG_ENABLED
|
||||||
{
|
{
|
||||||
static bool reported = false;
|
static bool reported = false;
|
||||||
if (!reported) {
|
if (!reported) {
|
||||||
reported = true;
|
reported = true;
|
||||||
std::fprintf(stderr,
|
std::fprintf(
|
||||||
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
|
stderr,
|
||||||
(void*) context,
|
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
|
||||||
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
|
(void*) context,
|
||||||
processors.size(),
|
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
|
||||||
context != nullptr && context->isMultithreadingEnabled()
|
processors.size(),
|
||||||
? context->getThreadPool().getMaxConcurrency()
|
context != nullptr && context->isMultithreadingEnabled() ? context->getThreadPool().getMaxConcurrency() : 0u);
|
||||||
: 0u);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -1055,9 +1321,10 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
|
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
|
||||||
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
|
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
|
||||||
Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
|
Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
|
||||||
bool skip = (!emptyCpu && candidateCompletion > currentDcpl)
|
bool skip =
|
||||||
|| addOrMax(slot.aest, candidateCompletion) >= bestComposite;
|
(!emptyCpu && candidateCompletion > currentDcpl) || addOrMax(slot.aest, candidateCompletion) >= bestComposite;
|
||||||
DCP_DEBUG_IF(gSelectTimers.precheck += std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
|
DCP_DEBUG_IF(gSelectTimers.precheck +=
|
||||||
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
|
||||||
if (skip)
|
if (skip)
|
||||||
continue;
|
continue;
|
||||||
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
|
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
|
||||||
@@ -1073,8 +1340,8 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
scheduleSnapshot = dcp_graph::captureLocalScheduleState(
|
scheduleSnapshot = dcp_graph::captureLocalScheduleState(
|
||||||
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
||||||
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
|
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
|
||||||
bool withinBudget = tryUpdateAestWithinBudget(
|
bool withinBudget =
|
||||||
candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
|
tryUpdateAestWithinBudget(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
|
||||||
if (!withinBudget) {
|
if (!withinBudget) {
|
||||||
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
|
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
|
||||||
taskInsertion.rollBack();
|
taskInsertion.rollBack();
|
||||||
@@ -1087,7 +1354,7 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
DCP_DEBUG_IF(gSelectTimers.snapshotInsertUpdate +=
|
DCP_DEBUG_IF(gSelectTimers.snapshotInsertUpdate +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
|
||||||
DCP_DEBUG_IF(++gSelectTimers.passedDcpl;)
|
DCP_DEBUG_IF(++gSelectTimers.passedDcpl;)
|
||||||
|
|
||||||
// Pick the tightest unscheduled child (smallest slack) and measure what
|
// Pick the tightest unscheduled child (smallest slack) and measure what
|
||||||
@@ -1135,7 +1402,7 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
dcp_graph::restoreLocalScheduleState(
|
dcp_graph::restoreLocalScheduleState(
|
||||||
scheduleSnapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
scheduleSnapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
|
||||||
DCP_DEBUG_IF(gSelectTimers.rollbackRestore +=
|
DCP_DEBUG_IF(gSelectTimers.rollbackRestore +=
|
||||||
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
|
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1150,7 +1417,9 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
|
|||||||
else {
|
else {
|
||||||
Time bestDcpl = std::numeric_limits<Time>::max();
|
Time bestDcpl = std::numeric_limits<Time>::max();
|
||||||
Time currentDcpl = getDcpl();
|
Time currentDcpl = getDcpl();
|
||||||
for (CPU c = 0; c < getLastCpu(); c++) {
|
for (CPU c : processors) {
|
||||||
|
if (c == getLastCpu())
|
||||||
|
continue;
|
||||||
auto slot = findSlot(candidate, c, false, relations);
|
auto slot = findSlot(candidate, c, false, relations);
|
 if (slot.aest == std::numeric_limits<Time>::max())
 slot = findSlot(candidate, c, true, relations);
@@ -1159,8 +1428,7 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
 // Cheap lower bound: post-insertion DCPL is at least max(currentDcpl,
 // candidate completion on this slot). Skip CPUs already worse than
 // the best seen.
-Time lowerBound =
-std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
+Time lowerBound = std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
 if (lowerBound >= bestDcpl)
 continue;
 auto snapshot = dcp_graph::captureLocalScheduleState(
@@ -1169,23 +1437,37 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
 updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder));
 Time candidateDcpl = getDcpl();
 taskInsertion.rollBack();
-dcp_graph::restoreLocalScheduleState(
-snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
+dcp_graph::restoreLocalScheduleState(snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
 if (candidateDcpl < bestDcpl) {
 bestDcpl = candidateDcpl;
 bestCpu = c;
 bestSlot = slot;
 }
 }
-if (bestCpu == -1) {
-bestCpu = 0;
-bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
-}
+if (bestCpu == -1)
+llvm::report_fatal_error("DCP scheduler: no valid slot found for task on any eligible CPU — "
+"all slots are blocked by already-placed descendants");
 }
 }
 if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount)
 incrementLastCpu();
 insertTaskInCPU(bestCpu, candidate, bestSlot.index);

+// Incremental AEST/ALST refresh replacing the full initAest/initAlst that
+// used to run after every placement. Post-insertion relations pick up any
+// new scheduling-edge ancestors/descendants introduced by the insertion.
+Time oldDcpl = getDcpl();
+CandidateRelations postRelations = dcp_graph::computeCandidateRelations(candidate);
+llvm::SmallVector<TaskDCP*, 32> postDescendantsTopoOrder;
+postDescendantsTopoOrder.reserve(postRelations.descendants.size());
+for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
+TaskDCP* current = &*it;
+if (current != candidate && postRelations.descendants.contains(current))
+postDescendantsTopoOrder.push_back(current);
+}
+updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(postDescendantsTopoOrder));
+updateAlstFromScheduledTask(candidate, postRelations, oldDcpl);
+return postRelations;
 }

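A minimal standalone sketch of the pruning idea above, with simplified stand-in types (Time, addOrMax and the slot/weight arguments here are placeholders for the scheduler's own definitions): a CPU is skipped as soon as a cheap lower bound on the post-insertion DCPL already reaches the best DCPL found so far, avoiding the snapshot/rollback trial insertion.

#include <algorithm>
#include <cstdint>
#include <limits>

using Time = uint64_t;

// Saturating add so an "unreachable" slot (aest == max) stays unreachable.
static Time addOrMax(Time a, Time b) {
  return (a > std::numeric_limits<Time>::max() - b) ? std::numeric_limits<Time>::max() : a + b;
}

// Returns true when this CPU cannot possibly beat bestDcpl, so the expensive
// trial insertion can be skipped entirely.
static bool canPrune(Time currentDcpl, Time slotAest, Time taskWeightOnCpu, Time bestDcpl) {
  Time lowerBound = std::max(currentDcpl, addOrMax(slotAest, taskWeightOnCpu));
  return lowerBound >= bestDcpl;
}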
 //===----------------------------------------------------------------------===//
@@ -1194,61 +1476,99 @@ void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {

 void GraphDCP::runDcp() {
 initTopological();
+initTaskStructureHashes();
 initAest();
 initAlst();
 dumpDot();

 dcp_graph::DcpProgressLogger progressLogger(nodes.size());
 llvm::DenseMap<TaskDCP*, int> unscheduledParents;
-std::vector<TaskDCP*> readyNodes;
-readyNodes.reserve(nodes.size());
+// Min-heap over ready tasks: tightest slack first, earliest AEST as tiebreak.
+// Lazy deletion: when AEST/ALST change after a placement, fresh entries are
+// pushed for the affected tasks. Stale ones are detected on pop by comparing
+// stored vs current (slack, aest) and re-pushed with the current values.
+struct ReadyEntry {
+Time slack;
+Time aest;
+TaskDCP* task;
+bool operator>(const ReadyEntry& other) const {
+if (slack != other.slack)
+return slack > other.slack;
+return aest > other.aest;
+}
+};
+std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>> readyQueue;
+size_t readyCount = 0;
+
+auto pushReady = [&](TaskDCP* node) {
+readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node});
+};
+
 for (auto& node : nodes) {
 int dependencyParents = dcp_graph::countDependencyParents(&node);
 unscheduledParents[&node] = dependencyParents;
-if (dependencyParents == 0)
-readyNodes.push_back(&node);
+if (dependencyParents == 0) {
+pushReady(&node);
+++readyCount;
+}
 }
-progressLogger.printStart(readyNodes.size());
+size_t xbarsCapacity = static_cast<size_t>(maxCpuCount) * onnx_mlir::crossbarCountInCore.getValue();
+progressLogger.printStart(readyCount, maxCpuCount, xbarsCapacity);
+
-while (!readyNodes.empty()) {
-DCP_DEBUG_IF(auto findStart = std::chrono::steady_clock::now();)
-TaskDCP* candidate = findCandidate(readyNodes);
-DCP_DEBUG_IF(progressLogger.recordFindDuration(
-std::chrono::duration<double>(std::chrono::steady_clock::now() - findStart).count());)
-fastRemove(readyNodes, candidate);
+while (readyCount > 0) {
+// Pop with lazy deletion: skip stale entries and re-push with current values.
+TaskDCP* candidate = nullptr;
+while (!readyQueue.empty()) {
+auto entry = readyQueue.top();
+readyQueue.pop();
+Time curSlack = slackOrZero(entry.task->getAest(), entry.task->getAlst());
+Time curAest = entry.task->getAest();
+if (entry.slack == curSlack && entry.aest == curAest) {
+candidate = entry.task;
+break;
+}
+readyQueue.push({curSlack, curAest, entry.task});
+}
+assert(candidate != nullptr && "readyCount > 0 but heap exhausted");
+--readyCount;

 DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();)
-selectProcessor(candidate, candidate->isCriticalPath());
+CandidateRelations postRelations = selectProcessor(candidate, candidate->isCriticalPath());
 DCP_DEBUG_IF(
 double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count();
 progressLogger.recordSelectDuration(selectSeconds);
-progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyNodes.size(), getLastCpu());
-)
+progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyCount, getLastCpu());)
+// Proactively refresh the heap priority for ready nodes whose AEST or ALST
+// changed: ancestors had their ALST individually recomputed; descendants had
+// their AEST bumped. Both may now sort differently than their stale entries.
+for (TaskDCP* node : postRelations.ancestors)
+if (!node->isScheduled() && unscheduledParents[node] == 0)
+pushReady(node);
+for (TaskDCP* node : postRelations.descendants)
+if (!node->isScheduled() && unscheduledParents[node] == 0)
+pushReady(node);
+
-DCP_DEBUG_IF(auto updateStart = std::chrono::steady_clock::now();)
-initAest();
-initAlst();
-DCP_DEBUG_IF(progressLogger.recordUpdateDuration(
-std::chrono::duration<double>(std::chrono::steady_clock::now() - updateStart).count());)
 progressLogger.advanceCompleted();
-progressLogger.printProgress(readyNodes.size(), getLastCpu(), "recompute", false);
+progressLogger.printProgress(readyCount, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), false);

 for (const auto& childEdge : candidate->children) {
 if (childEdge.isScheduling || childEdge.first->isScheduled())
 continue;
 int& dependencyParents = unscheduledParents[childEdge.first];
 assert(dependencyParents > 0 && "dependency parent count must stay positive");
-dependencyParents--;
-if (dependencyParents == 0)
-readyNodes.push_back(childEdge.first);
+--dependencyParents;
+if (dependencyParents == 0) {
+pushReady(childEdge.first);
+++readyCount;
+}
 }
-DCP_DEBUG_IF(
-++gSelectTimers.tasksProcessed;
-if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
-gSelectTimers.dump("tick");
-)
+DCP_DEBUG_IF(++gSelectTimers.tasksProcessed;
+if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
+gSelectTimers.dump("tick");)
 }
-progressLogger.printProgress(readyNodes.size(), getLastCpu(), "done", true);
+progressLogger.printProgress(0, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), true);
 dumpDot();
 }

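A self-contained sketch of the lazy-deletion ready queue above (Task, Time and slackOrZero are simplified stand-ins for the scheduler's own types). Instead of erasing heap entries whose priority went stale after a placement, the pop loop compares the stored key against the task's current key and re-pushes mismatches with fresh values.

#include <cstdint>
#include <functional>
#include <queue>
#include <vector>

using Time = uint64_t;

struct Task {
  Time aest = 0; // earliest start
  Time alst = 0; // latest start
};

static Time slackOrZero(Time aest, Time alst) { return alst > aest ? alst - aest : 0; }

struct ReadyEntry {
  Time slack;
  Time aest;
  Task* task;
  bool operator>(const ReadyEntry& other) const {
    if (slack != other.slack)
      return slack > other.slack;
    return aest > other.aest;
  }
};

using ReadyQueue = std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>>;

// Pop the entry whose stored priority still matches the task's current state;
// entries invalidated by AEST/ALST updates are re-pushed with corrected keys.
static Task* popReady(ReadyQueue& queue) {
  while (!queue.empty()) {
    ReadyEntry entry = queue.top();
    queue.pop();
    Time curSlack = slackOrZero(entry.task->aest, entry.task->alst);
    if (entry.slack == curSlack && entry.aest == entry.task->aest)
      return entry.task;
    queue.push({curSlack, entry.task->aest, entry.task});
  }
  return nullptr;
}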
@@ -4,6 +4,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"

+#include <cstdint>
 #include <list>
 #include <optional>
 #include <unordered_map>
@@ -48,8 +49,10 @@ private:

 std::vector<TaskDCP> nodes;
 onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
+std::vector<uint64_t> taskStructureHashes;
 std::vector<CpuTaskList> cpuTasks;
 std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
+llvm::DenseMap<CPU, uint64_t> cpuStructureHashes;
 CPU lastCpu = 0;
 long long flag = 1;
 Time dcpl = 0;
@@ -70,6 +73,7 @@ private:

 void initAest();
 void initAlst();
+void initTaskStructureHashes();

 Time computeAestOnCpu(TaskDCP* task, CPU cpu);
 Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
@@ -83,9 +87,15 @@ private:
 // `dcplBudget`, signalling that the new DCPL would exceed the budget.
 // Returns true iff the full propagation completed without exceeding the
 // budget. Uses the caller's snapshot to restore AEST on the aborted tail.
-bool tryUpdateAestWithinBudget(TaskDCP* task,
-llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
-Time dcplBudget);
+bool tryUpdateAestWithinBudget(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder, Time dcplBudget);
+
+// Incrementally refreshes ALST after `task` has been scheduled. Nodes
+// outside the backward cone (`relations.ancestors` plus `task`) retain
+// their relative distance to the sink boundary and only absorb the signed
+// DCPL delta (`newDcpl - oldDcpl`). `task` itself and its ancestors are
+// recomputed in reverse topological order so that new same-CPU transfer
+// costs (now zero) and scheduling-edge children are reflected.
+void updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl);

 void initTopological();
 void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
@@ -94,8 +104,11 @@ private:
 llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
 size_t getNodeIndex(const TaskDCP* task) const;

-TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes);
-void selectProcessor(TaskDCP* candidate, bool push);
+// Returns a compact dedup key for CPU `c` when evaluating `candidate`:
+// mixes candidateAest, crossbar usage, and the incremental cpu structure
+// hash into a single uint64_t. Zero heap allocation.
+uint64_t computeCpuCandidateKey(Time candidateAest, CPU cpu);
+CandidateRelations selectProcessor(TaskDCP* candidate, bool push);
 CPU getLastCpu() const { return lastCpu; }
 void incrementLastCpu() { lastCpu++; }
 FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
@@ -115,8 +128,7 @@ private:

 public:
 void runDcp();
-GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
-llvm::ArrayRef<IndexedEdge> edges)
+GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes, llvm::ArrayRef<IndexedEdge> edges)
 : nodes(), cpuTasks(), cpuCrossbarUsage() {
 for (auto spatCompute : spatComputes)
 nodes.emplace_back(spatCompute);
@@ -150,6 +162,11 @@ public:
 void setMaxCpuCount(int value) { maxCpuCount = value; }
 int getMaxCpuCount() const { return maxCpuCount; }

+// Total crossbar units allocated across all active CPUs.
+size_t crossbarsUsed() const;
+// Maximum crossbar units available across all active CPUs (lastCpu * per-CPU capacity).
+size_t crossbarsAvailable() const;
+
 // Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
 // null the scheduler runs single-threaded (tests use this path).
 void setContext(mlir::MLIRContext* ctx) { context = ctx; }
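The body of computeCpuCandidateKey is not part of this diff; the following is only an illustrative sketch of one way such a key could be mixed (a splitmix64-style finalizer over the three fields named in the comment above), with hypothetical argument names.

#include <cstdint>

// splitmix64-style finalizer: spreads the bits of x over the whole word.
static uint64_t mix64(uint64_t x) {
  x += 0x9e3779b97f4a7c15ULL;
  x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL;
  x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL;
  return x ^ (x >> 31);
}

// Hypothetical dedup key: two CPUs on which the candidate would start at the
// same time, with the same crossbar usage and the same structural schedule
// hash, are interchangeable, so only one of them needs a trial insertion.
static uint64_t cpuCandidateKey(uint64_t candidateAest, uint64_t crossbarUsage, uint64_t cpuStructureHash) {
  uint64_t key = mix64(candidateAest);
  key = mix64(key ^ crossbarUsage);
  key = mix64(key ^ cpuStructureHash);
  return key;
}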
@@ -35,10 +35,12 @@ void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSe
 void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
 void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }

-void DcpProgressLogger::printStart(size_t readyCount) const {
+void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
 if (!logProgress)
 return;
-llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount);
+llvm::errs() << llvm::formatv(
+"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
+totalTasks, readyCount, maxCpuCount, xbarsCapacity);
 }

 void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
@@ -48,14 +50,15 @@ void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
 if (!logProgress || elapsedSeconds < 1.0)
 return;

-llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n",
+llvm::errs() << llvm::formatv("[DCP] slow node={0} elapsed={1} ready={2} cpus={3}\n",
 nodeIndex,
 formatDuration(elapsedSeconds),
 readyCount,
 cpuCount);
 }

-void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) {
+void DcpProgressLogger::printProgress(
+size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force) {
 if (!logProgress)
 return;

@@ -68,19 +71,19 @@ void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::Str
 double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
 double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);

-llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n",
-completedTasks,
-totalTasks,
-percent,
-readyCount,
-cpuCount,
-stage,
-formatDuration(elapsedSeconds),
-completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds));
-llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n",
-formatDuration(findCandidateSeconds),
-formatDuration(selectProcessorSeconds),
-formatDuration(updateTimingSeconds));
+bool done = completedTasks == totalTasks;
+llvm::errs() << llvm::formatv(
+"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
+completedTasks,
+totalTasks,
+percent,
+readyCount,
+cpuCount,
+maxCpuCount,
+xbarsUsed,
+xbarsAvailable,
+llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
+done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
 lastProgressPrint = now;
 }

@@ -91,9 +94,9 @@ void DcpProgressLogger::recordFindDuration(double) {}
 void DcpProgressLogger::recordSelectDuration(double) {}
 void DcpProgressLogger::recordUpdateDuration(double) {}
 void DcpProgressLogger::advanceCompleted(size_t) {}
-void DcpProgressLogger::printStart(size_t) const {}
+void DcpProgressLogger::printStart(size_t, int, size_t) const {}
 void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
-void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {}
+void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}

 #endif

@@ -31,9 +31,10 @@ public:
 void recordUpdateDuration(double seconds);
 void advanceCompleted(size_t taskCount = 1);

-void printStart(size_t readyCount) const;
+void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
 void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
-void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force);
+void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
+size_t xbarsUsed, size_t xbarsAvailable, bool force);

 #ifdef DCP_DEBUG_ENABLED
 private:
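The ETA printed above follows the usual completed-rate extrapolation visible in the hunk (rate = completed tasks per elapsed second, ETA = remaining tasks at that rate). A standalone restatement of that arithmetic, with simplified argument names:

#include <cstddef>

// rate = tasks completed per second so far; eta = remaining work at that rate.
static double estimateEtaSeconds(size_t completedTasks, size_t totalTasks, double elapsedSeconds) {
  if (completedTasks == 0 || elapsedSeconds <= 0.0)
    return 0.0;
  double rate = static_cast<double>(completedTasks) / elapsedSeconds;
  return static_cast<double>(totalTasks - completedTasks) / rate;
}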
@@ -1,4 +1,5 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/PatternMatch.h"
@@ -10,23 +11,30 @@

 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_os_ostream.h"
+#include "llvm/Support/raw_ostream.h"

 #include <algorithm>
 #include <cstddef>
 #include <cstdint>
+#include <cstdlib>
 #include <fstream>
 #include <functional>
 #include <iterator>
+#include <limits>
 #include <memory>
 #include <optional>
 #include <tuple>
+#include <utility>
 #include <vector>

 #include "DCPGraph/DCPAnalysis.hpp"
 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
+#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"

 using namespace mlir;

@@ -34,6 +42,541 @@ namespace onnx_mlir {
|
|||||||
namespace {
|
namespace {
|
||||||
using SpatCompute = spatial::SpatCompute;
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||||
|
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op))
|
||||||
|
op = extract.getSource().getDefiningOp();
|
||||||
|
return dyn_cast_if_present<SpatCompute>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ComputeMotifInfo {
|
||||||
|
uint64_t instructionCount = 0;
|
||||||
|
uint64_t weightedMvmCount = 0;
|
||||||
|
uint64_t weightedVmmCount = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
void appendUnique(SmallVector<size_t>& values, size_t value) {
|
||||||
|
if (!llvm::is_contained(values, value))
|
||||||
|
values.push_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||||
|
if (!compute->hasOneUse())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto& use = *compute->getUses().begin();
|
||||||
|
auto user = dyn_cast<SpatCompute>(use.getOwner());
|
||||||
|
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) {
|
||||||
|
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights()))
|
||||||
|
targetWeightIndices[weight].push_back(weightIndex);
|
||||||
|
|
||||||
|
DenseMap<Value, size_t> usedSourceWeightOccurrences;
|
||||||
|
SmallVector<size_t> sourceToTargetIndex;
|
||||||
|
sourceToTargetIndex.reserve(sourceWeights.size());
|
||||||
|
auto targetWeights = target.getWeightsMutable();
|
||||||
|
for (Value weight : sourceWeights) {
|
||||||
|
size_t occurrence = usedSourceWeightOccurrences[weight]++;
|
||||||
|
auto& matchingIndices = targetWeightIndices[weight];
|
||||||
|
if (occurrence >= matchingIndices.size()) {
|
||||||
|
size_t newIndex = target.getWeights().size();
|
||||||
|
targetWeights.append(weight);
|
||||||
|
matchingIndices.push_back(newIndex);
|
||||||
|
sourceToTargetIndex.push_back(newIndex);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
sourceToTargetIndex.push_back(matchingIndices[occurrence]);
|
||||||
|
}
|
||||||
|
return sourceToTargetIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||||
|
Location loc = funcOp.getLoc();
|
||||||
|
IRRewriter rewriter(funcOp->getContext());
|
||||||
|
SmallVector<SpatCompute> trivialComputes;
|
||||||
|
llvm::SmallSet<SpatCompute, 8> toErase;
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||||
|
if (isTrivialSerialMergeCandidate(compute))
|
||||||
|
trivialComputes.push_back(compute);
|
||||||
|
|
||||||
|
while (!trivialComputes.empty()) {
|
||||||
|
auto compute = trivialComputes.front();
|
||||||
|
|
||||||
|
if (compute.use_empty()) {
|
||||||
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto& computeUse = *compute->getUses().begin();
|
||||||
|
auto child = cast<SpatCompute>(computeUse.getOwner());
|
||||||
|
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||||
|
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights());
|
||||||
|
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
|
||||||
|
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex]));
|
||||||
|
|
||||||
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
|
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||||
|
newTerminator->erase();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
|
auto remapWeightIndex = [&](auto weightedOp) {
|
||||||
|
auto oldIndex = weightedOp.getWeightIndex();
|
||||||
|
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
|
||||||
|
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& op : child.getBody().front()) {
|
||||||
|
auto newInst = rewriter.clone(op, mapper);
|
||||||
|
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
|
||||||
|
remapWeightIndex(weightedMvmOp);
|
||||||
|
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
|
||||||
|
remapWeightIndex(weightedVmmOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
child.replaceAllUsesWith(newCompute);
|
||||||
|
toErase.insert(child);
|
||||||
|
|
||||||
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
toErase.insert(compute);
|
||||||
|
|
||||||
|
if (isTrivialSerialMergeCandidate(newCompute))
|
||||||
|
trivialComputes.push_back(newCompute);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : toErase) {
|
||||||
|
for (Value result : compute->getResults())
|
||||||
|
result.dropAllUses();
|
||||||
|
compute.erase();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WeightedVmmBandCandidate {
|
||||||
|
Operation* parent;
|
||||||
|
SpatCompute compute;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool isSingleWeightedVmmCompute(SpatCompute compute) {
|
||||||
|
if (compute.getNumResults() != 1 || compute.getWeights().size() != 1 || compute.getInputs().size() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
uint64_t weightedVmmCount = 0;
|
||||||
|
for (Operation& op : compute.getBody().front()) {
|
||||||
|
if (isa<spatial::SpatWeightedMVMOp>(&op))
|
||||||
|
return false;
|
||||||
|
if (isa<spatial::SpatWeightedVMMOp>(&op))
|
||||||
|
weightedVmmCount++;
|
||||||
|
}
|
||||||
|
return weightedVmmCount == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<WeightedVmmBandCandidate> getWeightedVmmBandCandidate(SpatCompute compute) {
|
||||||
|
if (!isSingleWeightedVmmCompute(compute) || !compute->hasOneUse())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto& use = *compute->getUses().begin();
|
||||||
|
auto child = dyn_cast<SpatCompute>(use.getOwner());
|
||||||
|
if (!child || use.getOperandNumber() < child.getWeights().size())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto parent = getOriginalSpatCompute(compute.getInputs().front().getDefiningOp());
|
||||||
|
if (!parent || parent == child)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
return WeightedVmmBandCandidate {parent.getOperation(), compute};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool haveSameWeightedVmmBandShape(SpatCompute lhs, SpatCompute rhs) {
|
||||||
|
return lhs.getWeights().front().getType() == rhs.getWeights().front().getType()
|
||||||
|
&& lhs.getInputs().front().getType() == rhs.getInputs().front().getType()
|
||||||
|
&& lhs.getResult(0).getType() == rhs.getResult(0).getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
SpatCompute packWeightedVmmComputes(func::FuncOp funcOp, ArrayRef<SpatCompute> computes) {
|
||||||
|
assert(!computes.empty() && "expected at least one compute to pack");
|
||||||
|
|
||||||
|
IRRewriter rewriter(funcOp->getContext());
|
||||||
|
SpatCompute child = cast<SpatCompute>((*computes.front()->getUses().begin()).getOwner());
|
||||||
|
rewriter.setInsertionPoint(child);
|
||||||
|
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
SmallVector<Type> inputTypes;
|
||||||
|
SmallVector<Location> inputLocs;
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
operands.reserve(computes.size() * 2);
|
||||||
|
inputTypes.reserve(computes.size());
|
||||||
|
inputLocs.reserve(computes.size());
|
||||||
|
resultTypes.reserve(computes.size());
|
||||||
|
|
||||||
|
for (SpatCompute compute : computes)
|
||||||
|
for (Value weight : compute.getWeights())
|
||||||
|
operands.push_back(weight);
|
||||||
|
for (SpatCompute compute : computes) {
|
||||||
|
for (Value input : compute.getInputs()) {
|
||||||
|
operands.push_back(input);
|
||||||
|
inputTypes.push_back(input.getType());
|
||||||
|
inputLocs.push_back(input.getLoc());
|
||||||
|
}
|
||||||
|
for (Type resultType : compute.getResultTypes())
|
||||||
|
resultTypes.push_back(resultType);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto packedCompute = SpatCompute::create(rewriter, funcOp.getLoc(), resultTypes, operands);
|
||||||
|
packedCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(computes.size()), static_cast<int>(inputTypes.size())});
|
||||||
|
auto* block = rewriter.createBlock(&packedCompute.getBody(), packedCompute.getBody().end(), inputTypes, inputLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(block);
|
||||||
|
|
||||||
|
SmallVector<Value> yieldValues;
|
||||||
|
yieldValues.reserve(resultTypes.size());
|
||||||
|
size_t inputBaseIndex = 0;
|
||||||
|
size_t weightBaseIndex = 0;
|
||||||
|
for (SpatCompute compute : computes) {
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||||
|
mapper.map(weight, *std::next(packedCompute.getWeights().begin(), weightBaseIndex + weightIndex));
|
||||||
|
for (auto [inputIndex, bbArg] : llvm::enumerate(compute.getBody().front().getArguments()))
|
||||||
|
mapper.map(bbArg, block->getArgument(inputBaseIndex + inputIndex));
|
||||||
|
|
||||||
|
auto remapWeightIndex = [&](auto weightedOp) {
|
||||||
|
weightedOp.setWeightIndex(weightBaseIndex + weightedOp.getWeightIndex());
|
||||||
|
};
|
||||||
|
for (Operation& op : compute.getBody().front()) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
yieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(cloned))
|
||||||
|
remapWeightIndex(weightedMvmOp);
|
||||||
|
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(cloned))
|
||||||
|
remapWeightIndex(weightedVmmOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
inputBaseIndex += compute.getInputs().size();
|
||||||
|
weightBaseIndex += compute.getWeights().size();
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), yieldValues);
|
||||||
|
|
||||||
|
size_t resultIndex = 0;
|
||||||
|
for (SpatCompute compute : computes)
|
||||||
|
for (OpResult result : compute->getResults())
|
||||||
|
result.replaceAllUsesWith(packedCompute.getResult(resultIndex++));
|
||||||
|
for (SpatCompute compute : llvm::reverse(computes))
|
||||||
|
compute.erase();
|
||||||
|
|
||||||
|
return packedCompute;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getFastPathCpuBudget() {
|
||||||
|
constexpr size_t kDefaultMaxCpuCount = 1000;
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
return static_cast<size_t>(coresCount.getValue());
|
||||||
|
return kDefaultMaxCpuCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t packWideWeightedVmmBands(func::FuncOp funcOp) {
|
||||||
|
constexpr size_t kMinBandSize = 64;
|
||||||
|
size_t cpuBudget = std::max<size_t>(1, getFastPathCpuBudget());
|
||||||
|
SmallVector<WeightedVmmBandCandidate> candidates;
|
||||||
|
for (SpatCompute compute : funcOp.getOps<SpatCompute>())
|
||||||
|
if (auto candidate = getWeightedVmmBandCandidate(compute))
|
||||||
|
candidates.push_back(*candidate);
|
||||||
|
|
||||||
|
llvm::stable_sort(candidates, [](const WeightedVmmBandCandidate& lhs, const WeightedVmmBandCandidate& rhs) {
|
||||||
|
if (lhs.parent != rhs.parent)
|
||||||
|
return lhs.parent < rhs.parent;
|
||||||
|
return lhs.compute->isBeforeInBlock(rhs.compute);
|
||||||
|
});
|
||||||
|
|
||||||
|
size_t packedBandCount = 0;
|
||||||
|
size_t packedNodeCount = 0;
|
||||||
|
size_t createdNodeCount = 0;
|
||||||
|
size_t minChunkSize = std::numeric_limits<size_t>::max();
|
||||||
|
size_t maxChunkSize = 0;
|
||||||
|
SmallVector<SmallVector<SpatCompute>> shapeGroups;
|
||||||
|
for (size_t parentBegin = 0; parentBegin < candidates.size();) {
|
||||||
|
size_t parentEnd = parentBegin + 1;
|
||||||
|
while (parentEnd < candidates.size() && candidates[parentEnd].parent == candidates[parentBegin].parent)
|
||||||
|
parentEnd++;
|
||||||
|
|
||||||
|
shapeGroups.clear();
|
||||||
|
for (size_t index = parentBegin; index < parentEnd; ++index) {
|
||||||
|
SpatCompute compute = candidates[index].compute;
|
||||||
|
auto groupIt = llvm::find_if(shapeGroups, [&](const SmallVector<SpatCompute>& group) {
|
||||||
|
return haveSameWeightedVmmBandShape(group.front(), compute);
|
||||||
|
});
|
||||||
|
if (groupIt == shapeGroups.end())
|
||||||
|
shapeGroups.push_back({compute});
|
||||||
|
else
|
||||||
|
groupIt->push_back(compute);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (ArrayRef<SpatCompute> band : shapeGroups) {
|
||||||
|
size_t bandSize = band.size();
|
||||||
|
if (bandSize < kMinBandSize || bandSize <= cpuBudget)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t chunkSize = (bandSize + cpuBudget - 1) / cpuBudget;
|
||||||
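// Worked example of the ceil-division above (values illustrative, not taken
// from the diff): a band of 200 same-shape single-vmm computes with a
// cpuBudget of 64 gives chunkSize = (200 + 63) / 64 = 4, i.e. ceil(200/64);
// packing 4 computes per node yields ceil(200/4) = 50 packed nodes, which
// fits under the 64-CPU budget instead of handing the scheduler 200 tiny nodes.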
|
packedBandCount++;
|
||||||
|
SmallVector<SpatCompute> chunk;
|
||||||
|
chunk.reserve(std::min(chunkSize, bandSize));
|
||||||
|
for (auto [index, compute] : llvm::enumerate(band)) {
|
||||||
|
chunk.push_back(compute);
|
||||||
|
if (chunk.size() == chunkSize || index + 1 == band.size()) {
|
||||||
|
minChunkSize = std::min(minChunkSize, chunk.size());
|
||||||
|
maxChunkSize = std::max(maxChunkSize, chunk.size());
|
||||||
|
packWeightedVmmComputes(funcOp, chunk);
|
||||||
|
packedNodeCount += chunk.size();
|
||||||
|
createdNodeCount++;
|
||||||
|
chunk.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parentBegin = parentEnd;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (packedNodeCount != 0)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-FASTPATH] wvmmBands={0} packedNodes={1} createdNodes={2} changed={3} "
|
||||||
|
"cpuBudget={4} chunkSizeRange={5}-{6}\n",
|
||||||
|
packedBandCount,
|
||||||
|
packedNodeCount,
|
||||||
|
createdNodeCount,
|
||||||
|
packedNodeCount - createdNodeCount,
|
||||||
|
cpuBudget,
|
||||||
|
minChunkSize,
|
||||||
|
maxChunkSize);
|
||||||
|
return packedNodeCount - createdNodeCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
void emitMotifProfile(func::FuncOp funcOp) {
|
||||||
|
if (!std::getenv("DCP_MOTIF_PROFILE"))
|
||||||
|
return;
|
||||||
|
|
||||||
|
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||||
|
DenseMap<SpatCompute, size_t> computeToIndex;
|
||||||
|
computeToIndex.reserve(computes.size());
|
||||||
|
for (auto [index, compute] : llvm::enumerate(computes))
|
||||||
|
computeToIndex[compute] = index;
|
||||||
|
|
||||||
|
SmallVector<ComputeMotifInfo> computeInfos(computes.size());
|
||||||
|
SmallVector<SmallVector<size_t>> parents(computes.size());
|
||||||
|
SmallVector<SmallVector<size_t>> children(computes.size());
|
||||||
|
|
||||||
|
uint64_t weightedVmmNodeCount = 0;
|
||||||
|
uint64_t weightedVmmOpCount = 0;
|
||||||
|
uint64_t edgeCount = 0;
|
||||||
|
|
||||||
|
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||||
|
ComputeMotifInfo& info = computeInfos[index];
|
||||||
|
for (Operation& op : compute.getBody().front()) {
|
||||||
|
info.instructionCount++;
|
||||||
|
if (isa<spatial::SpatWeightedMVMOp>(&op))
|
||||||
|
info.weightedMvmCount++;
|
||||||
|
if (isa<spatial::SpatWeightedVMMOp>(&op))
|
||||||
|
info.weightedVmmCount++;
|
||||||
|
}
|
||||||
|
if (info.weightedVmmCount > 0) {
|
||||||
|
weightedVmmNodeCount++;
|
||||||
|
weightedVmmOpCount += info.weightedVmmCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Value input : compute.getInputs()) {
|
||||||
|
auto parent = getOriginalSpatCompute(input.getDefiningOp());
|
||||||
|
if (!parent || parent == compute)
|
||||||
|
continue;
|
||||||
|
auto parentIt = computeToIndex.find(parent);
|
||||||
|
if (parentIt == computeToIndex.end())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t parentIndex = parentIt->second;
|
||||||
|
size_t oldParentCount = parents[index].size();
|
||||||
|
appendUnique(parents[index], parentIndex);
|
||||||
|
if (parents[index].size() != oldParentCount) {
|
||||||
|
appendUnique(children[parentIndex], index);
|
||||||
|
edgeCount++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t maxFanIn = 0;
|
||||||
|
uint64_t maxFanOut = 0;
|
||||||
|
uint64_t fanIn16 = 0;
|
||||||
|
uint64_t fanIn64 = 0;
|
||||||
|
uint64_t fanIn256 = 0;
|
||||||
|
uint64_t fanOut16 = 0;
|
||||||
|
uint64_t fanOut64 = 0;
|
||||||
|
uint64_t fanOut256 = 0;
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index) {
|
||||||
|
uint64_t fanIn = parents[index].size();
|
||||||
|
uint64_t fanOut = children[index].size();
|
||||||
|
maxFanIn = std::max(maxFanIn, fanIn);
|
||||||
|
maxFanOut = std::max(maxFanOut, fanOut);
|
||||||
|
fanIn16 += fanIn >= 16;
|
||||||
|
fanIn64 += fanIn >= 64;
|
||||||
|
fanIn256 += fanIn >= 256;
|
||||||
|
fanOut16 += fanOut >= 16;
|
||||||
|
fanOut64 += fanOut >= 64;
|
||||||
|
fanOut256 += fanOut >= 256;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t serialChainCount = 0;
|
||||||
|
uint64_t serialChainNodeCount = 0;
|
||||||
|
uint64_t maxSerialChain = 0;
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index) {
|
||||||
|
if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
uint64_t chainLength = 1;
|
||||||
|
size_t current = index;
|
||||||
|
while (children[current].size() == 1) {
|
||||||
|
size_t child = children[current][0];
|
||||||
|
if (parents[child].size() != 1)
|
||||||
|
break;
|
||||||
|
chainLength++;
|
||||||
|
current = child;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chainLength >= 2) {
|
||||||
|
serialChainCount++;
|
||||||
|
serialChainNodeCount += chainLength;
|
||||||
|
maxSerialChain = std::max(maxSerialChain, chainLength);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<size_t> incomingEdgeCount;
|
||||||
|
incomingEdgeCount.reserve(parents.size());
|
||||||
|
for (ArrayRef<size_t> parentList : parents)
|
||||||
|
incomingEdgeCount.push_back(parentList.size());
|
||||||
|
|
||||||
|
SmallVector<uint64_t> level(computes.size(), 0);
|
||||||
|
SmallVector<size_t> readyNodes;
|
||||||
|
readyNodes.reserve(computes.size());
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index)
|
||||||
|
if (incomingEdgeCount[index] == 0)
|
||||||
|
readyNodes.push_back(index);
|
||||||
|
|
||||||
|
size_t readyIndex = 0;
|
||||||
|
while (readyIndex != readyNodes.size()) {
|
||||||
|
size_t current = readyNodes[readyIndex++];
|
||||||
|
for (size_t child : children[current]) {
|
||||||
|
level[child] = std::max(level[child], level[current] + 1);
|
||||||
|
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||||
|
incomingEdgeCount[child]--;
|
||||||
|
if (incomingEdgeCount[child] == 0)
|
||||||
|
readyNodes.push_back(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
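// Worked example of the level computation above (illustrative only): for a
// chain A -> B -> C with an extra edge A -> C, the Kahn-style traversal
// assigns level(A)=0, level(B)=1 and level(C)=max(0+1, 1+1)=2, so a node is
// counted at the deepest level any of its parents forces, not the shallowest.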
|
|
||||||
|
SmallVector<uint64_t> weightedVmmNodesByLevel;
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index) {
|
||||||
|
if (computeInfos[index].weightedVmmCount == 0)
|
||||||
|
continue;
|
||||||
|
if (weightedVmmNodesByLevel.size() <= level[index])
|
||||||
|
weightedVmmNodesByLevel.resize(level[index] + 1, 0);
|
||||||
|
weightedVmmNodesByLevel[level[index]]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t maxWeightedVmmLevel = 0;
|
||||||
|
uint64_t wideWeightedVmmLevels64 = 0;
|
||||||
|
uint64_t wideWeightedVmmLevels256 = 0;
|
||||||
|
for (uint64_t count : weightedVmmNodesByLevel) {
|
||||||
|
maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
|
||||||
|
wideWeightedVmmLevels64 += count >= 64;
|
||||||
|
wideWeightedVmmLevels256 += count >= 256;
|
||||||
|
}
|
||||||
|
|
||||||
|
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
|
||||||
|
SmallVector<ShapeKey> weightedVmmShapeKeys;
|
||||||
|
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||||
|
const ComputeMotifInfo& info = computeInfos[index];
|
||||||
|
if (info.weightedVmmCount == 0)
|
||||||
|
continue;
|
||||||
|
weightedVmmShapeKeys.push_back({info.instructionCount,
|
||||||
|
info.weightedVmmCount,
|
||||||
|
info.weightedMvmCount,
|
||||||
|
static_cast<uint64_t>(compute.getWeights().size()),
|
||||||
|
static_cast<uint64_t>(compute.getInputs().size()),
|
||||||
|
static_cast<uint64_t>(parents[index].size()),
|
||||||
|
static_cast<uint64_t>(children[index].size())});
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(weightedVmmShapeKeys);
|
||||||
|
SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
|
||||||
|
for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
|
||||||
|
size_t next = index + 1;
|
||||||
|
while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
|
||||||
|
next++;
|
||||||
|
weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
|
||||||
|
index = next;
|
||||||
|
}
|
||||||
|
llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
|
||||||
|
if (lhs.first != rhs.first)
|
||||||
|
return lhs.first > rhs.first;
|
||||||
|
return lhs.second < rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
|
||||||
|
"serialChains={4} serialChainNodes={5} maxSerialChain={6} "
|
||||||
|
"maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
|
||||||
|
"fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
|
||||||
|
computes.size(),
|
||||||
|
edgeCount,
|
||||||
|
weightedVmmNodeCount,
|
||||||
|
weightedVmmOpCount,
|
||||||
|
serialChainCount,
|
||||||
|
serialChainNodeCount,
|
||||||
|
maxSerialChain,
|
||||||
|
maxFanIn,
|
||||||
|
maxFanOut,
|
||||||
|
fanIn16,
|
||||||
|
fanIn64,
|
||||||
|
fanIn256,
|
||||||
|
fanOut16,
|
||||||
|
fanOut64,
|
||||||
|
fanOut256,
|
||||||
|
readyNodes.size());
|
||||||
|
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
|
||||||
|
"shapeGroups={4}\n",
|
||||||
|
weightedVmmNodesByLevel.size(),
|
||||||
|
maxWeightedVmmLevel,
|
||||||
|
wideWeightedVmmLevels64,
|
||||||
|
wideWeightedVmmLevels256,
|
||||||
|
weightedVmmShapeCounts.size());
|
||||||
|
|
||||||
|
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
|
||||||
|
auto [count, shape] = weightedVmmShapeCounts[rank];
|
||||||
|
auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape;
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
|
||||||
|
"mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n",
|
||||||
|
rank,
|
||||||
|
count,
|
||||||
|
insts,
|
||||||
|
vmmOps,
|
||||||
|
mvmOps,
|
||||||
|
weights,
|
||||||
|
inputs,
|
||||||
|
fanIn,
|
||||||
|
fanOut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void generateReport(func::FuncOp funcOp, const std::string& name) {
|
void generateReport(func::FuncOp funcOp, const std::string& name) {
|
||||||
std::string outputDir = getOutputDir();
|
std::string outputDir = getOutputDir();
|
||||||
if (outputDir.empty())
|
if (outputDir.empty())
|
||||||
@@ -213,6 +756,10 @@ public:
 }

 void runOnOperation() override {
+mergeTriviallyConnectedComputes(getOperation());
+packWideWeightedVmmBands(getOperation());
+emitMotifProfile(getOperation());
+
 DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
 auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
 auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
@@ -342,22 +889,34 @@ private:
 IRRewriter rewriter(&getContext());
 IRMapping mapper;

+DenseMap<Value, SmallVector<size_t, 4>> toWeightIndices;
+for (auto [weightIndex, weight] : llvm::enumerate(toCompute.getWeights()))
+toWeightIndices[weight].push_back(weightIndex);
+DenseMap<Value, size_t> usedFromWeightOccurrences;
+SmallVector<size_t> fromWeightToNewIndex;
+fromWeightToNewIndex.reserve(fromCompute.getWeights().size());
+
 auto weightMutableIter = toCompute.getWeightsMutable();
 for (auto weight : fromCompute.getWeights()) {
-auto founded = llvm::find(toCompute.getWeights(), weight);
-if (founded == toCompute.getWeights().end()) {
+size_t occurrence = usedFromWeightOccurrences[weight]++;
+auto& matchingIndices = toWeightIndices[weight];
+if (occurrence >= matchingIndices.size()) {
 size_t sizeW = toCompute.getWeights().size();
 size_t sizeI = toCompute.getInputs().size();
 weightMutableIter.append(weight);
 auto last = weightMutableIter.end();
 last = std::prev(last, 1);
 mapper.map(weight, last->get());
+matchingIndices.push_back(sizeW);
+fromWeightToNewIndex.push_back(sizeW);
 assert(sizeW + 1 == toCompute.getWeights().size());
 assert(sizeI == toCompute.getInputs().size());
 assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
 }
 else {
-mapper.map(weight, *founded);
+size_t newIndex = matchingIndices[occurrence];
+mapper.map(weight, *std::next(toCompute.getWeights().begin(), newIndex));
+fromWeightToNewIndex.push_back(newIndex);
 }
 }

@@ -409,9 +968,8 @@ private:
 ComputeValueResults computeValueResults;
 auto remapWeightIndex = [&](auto weightedOp) {
 auto oldIndex = weightedOp.getWeightIndex();
-auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
-auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
-weightedOp.setWeightIndex(newIndex);
+assert(static_cast<size_t>(oldIndex) < fromWeightToNewIndex.size() && "weight index out of range");
+weightedOp.setWeightIndex(fromWeightToNewIndex[oldIndex]);
 };
 for (auto& op : fromCompute.getOps()) {
 if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
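A self-contained sketch of the occurrence-based index mapping introduced in the hunk above, using plain STL stand-ins for the MLIR value types: the i-th occurrence of a duplicated weight in the source list reuses the i-th matching slot in the target, and only appends a new slot once the matches are exhausted.

#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Maps each source weight to a target index, reusing distinct matching slots
// for repeated weights and appending to the target otherwise.
static std::vector<size_t> buildIndexMap(std::vector<std::string>& targetWeights,
                                         const std::vector<std::string>& sourceWeights) {
  std::unordered_map<std::string, std::vector<size_t>> targetIndices;
  for (size_t i = 0; i < targetWeights.size(); ++i)
    targetIndices[targetWeights[i]].push_back(i);

  std::unordered_map<std::string, size_t> occurrences;
  std::vector<size_t> sourceToTarget;
  sourceToTarget.reserve(sourceWeights.size());
  for (const std::string& w : sourceWeights) {
    size_t occurrence = occurrences[w]++;
    auto& matching = targetIndices[w];
    if (occurrence >= matching.size()) {
      size_t newIndex = targetWeights.size();
      targetWeights.push_back(w);
      matching.push_back(newIndex);
      sourceToTarget.push_back(newIndex);
    } else {
      sourceToTarget.push_back(matching[occurrence]);
    }
  }
  return sourceToTarget;
}
// Example: target {"A","B"}, source {"A","A","C"} -> indices {0, 2, 3},
// with the second "A" and "C" appended to the target as slots 2 and 3.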
@@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
 auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");

 OpBuilder::InsertionGuard guard(rewriter);
-rewriter.setInsertionPoint(coreOp);
-auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());

 rewriter.setInsertionPoint(mapOp);
+auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
 auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
 pim::PimMemCopyOp::create(rewriter,
 mapOp.getLoc(),
@@ -457,6 +457,10 @@ int testDCPGraphDiamondDependencies() {
 return 0;
 }

+// crossbarSize=4, crossbarCount=2 => capacity = 4*4*2 = 32.
+// Each task with crossbarUsage=1 needs footprint = 4*4 = 16, so at most 1 task
+// can fit per CPU (16+16 = 32 >= capacity). The scheduler must open a fresh CPU
+// for each task; all three end up on separate CPUs with their base weight.
 int testDCPGraphCrossbarExhaustion() {
 std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl;
 configureDcpDotOutput();
@@ -474,36 +478,35 @@ int testDCPGraphCrossbarExhaustion() {
 const std::vector<Weight> nodeWeights = {10, 10, 10};
 const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1};
 GraphDCP graph(nodeWeights, {}, nodeCrossbarUsage);
-graph.setMaxCpuCount(1);
+graph.setMaxCpuCount(3);
 graph.runDcp();

-if (graph.cpuCount() != 1) {
+if (graph.cpuCount() != 3) {
 restoreCrossbarOptions();
-std::cerr << "Expected exactly 1 CPU with maxCpuCount=1, got " << graph.cpuCount() << "\n";
+std::cerr << "Expected 3 CPUs (one per task due to crossbar limit), got " << graph.cpuCount() << "\n";
 dumpDcpFailureArtifacts();
 return 1;
 }

-auto scheduledTasks = graph.getScheduledTasks(0);
-if (scheduledTasks.size() != 3) {
-restoreCrossbarOptions();
-std::cerr << "Expected all three tasks to be scheduled on CPU 0\n";
-printCpuSchedule(graph, 0);
-dumpDcpFailureArtifacts();
-return 1;
-}
-if (scheduledTasks[0].weight != 10 || scheduledTasks[1].weight != std::numeric_limits<Weight>::max()
-|| scheduledTasks[2].weight != std::numeric_limits<Weight>::max()) {
-restoreCrossbarOptions();
-std::cerr << "Unexpected effective weights under crossbar exhaustion\n";
-printCpuSchedule(graph, 0);
-dumpDcpFailureArtifacts();
-return 1;
-}
+int failures = 0;
+for (CPU c = 0; c < 3; c++) {
+auto scheduledTasks = graph.getScheduledTasks(c);
+if (scheduledTasks.size() != 1) {
+std::cerr << "Expected exactly 1 task on CPU " << c << ", got " << scheduledTasks.size() << "\n";
+printCpuSchedule(graph, c);
+failures++;
+continue;
+}
+if (scheduledTasks[0].weight != 10) {
+std::cerr << "Expected weight=10 on CPU " << c << ", got " << scheduledTasks[0].weight << "\n";
+printCpuSchedule(graph, c);
+failures++;
+}
 }

 restoreCrossbarOptions();
-return 0;
+if (failures) dumpDcpFailureArtifacts();
+return failures;
 }

 } // namespace