1 Commits

Author SHA1 Message Date
NiccoloN
87922d994f multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 1h2m3s
2026-04-22 18:29:06 +02:00
28 changed files with 669 additions and 2538 deletions

154
README.md
View File

@@ -1,159 +1,5 @@
# Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
Because of this, the amount of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` — materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
- `--maccel=PIM` — select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
--crossbar-size 2048
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
## Rebuilding
Release build (fast):
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
## Build
### Protobuf

View File

@@ -55,23 +55,15 @@ pub trait HasSigm {
impl HasSigm for f32 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
let ex = self.exp();
ex / (1.0 + ex)
}
}
impl HasSigm for f64 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
let ex = self.exp();
ex / (1.0 + ex)
}
}

View File

@@ -1,13 +1,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
@@ -100,82 +96,6 @@ void markWeightAlways(Operation* op) {
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](Operation* op) {
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
SmallPtrSet<Value, 8> visited;
auto walkUses = [&](Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
Operation* user = use.getOwner();
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
});
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};

View File

@@ -7,7 +7,6 @@
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
@@ -41,11 +40,6 @@ bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>

View File

@@ -3,11 +3,9 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
@@ -55,23 +53,9 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
SmallVector<mlir::Value> args;
for (mlir::Value arg : funcOp.getArguments()){
gatherMemEntry(arg);
args.push_back(arg);
}
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (globalMemrefOp.getName().starts_with("arg")){
StringRef indexStr = globalMemrefOp.getName().substr(4);
int index = 0;
llvm::to_integer(indexStr,index, 10);
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
}
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (inserted)
gatherMemEntry(getGlobalOp.getResult());
@@ -80,6 +64,8 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
}
});
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
@@ -426,9 +412,6 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
emitInstruction(std::move(json));
}
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
@@ -598,8 +581,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();

View File

@@ -106,7 +106,6 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
};

View File

@@ -41,7 +41,7 @@ llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
@@ -51,7 +51,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(4000));
llvm::cl::init(1024));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",

View File

@@ -1,4 +1,3 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -9,10 +8,10 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
@@ -25,6 +24,8 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -51,6 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
@@ -85,7 +87,8 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -146,7 +149,6 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
@@ -154,6 +156,8 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
mergeTriviallyConnectedComputes(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial0");
}
@@ -163,36 +167,19 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
auto source = toRemoveOp.getSource();
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
}
return false;
}
@@ -201,24 +188,27 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
return false;
}
@@ -280,72 +270,6 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
return cast<Value>(mapped);
}
bool sourceOpernadHasWeightAlways(Operation* op) {
if (op == nullptr)
return false;
Operation* source = nullptr;
do {
if (isa<spatial::SpatCompute>(*op)) {
return false;
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
auto tmpSource = extractSliceOp.getSource();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
auto tmpSource = expandShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
auto tmpSource = transposeOp.getData();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
auto tmpSource = collapseShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
source = constantOp;
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else {
op->dump();
llvm_unreachable("Global instruction not handle in func");
}
}
while (source == nullptr);
if (hasWeightAlways(source))
return true;
return false;
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
@@ -354,12 +278,8 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute>(instruction) || isa<func::ReturnOp>(instruction)
|| sourceOpernadHasWeightAlways(&instruction))
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
@@ -375,9 +295,105 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
}
}
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) {
auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute);
}
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto& use = *newCompute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
}

View File

@@ -289,7 +289,8 @@ static SmallVector<Value> createIm2colRowComputes(Value x,
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
@@ -325,9 +326,10 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value packedOutput =
gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
@@ -503,41 +505,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
// the row it needs instead of receiving a full packed tensor and slicing it locally.
auto gemmInputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
xType,
im2colType,
rowType,
gemmInputRowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmB = buildPackedWeight(wDenseAttr,
wTrans,

View File

@@ -2,7 +2,6 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
@@ -15,108 +14,7 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2)
return value;
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
return tensor::CollapseShapeOp::create(rewriter,
loc,
matrixType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static bool isConstantLikeOperand(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (definingOp->hasTrait<OpTrait::ConstantLike>())
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -126,115 +24,80 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
return failure();
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t batch = std::max(lhsBatch, rhsBatch);
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure();
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
if (k != rhsK)
return failure();
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
return failure();
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB());
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhs = matmulOp.getA();
Value rhs = matmulOp.getB();
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) {
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}));
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
SmallVector<Value> batchResults;
batchResults.reserve(batch);
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm)
gemmResult = ONNXTransposeOp::create(
rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
gemmResult,
rewriter.getI64ArrayAttr({1, 0}));
batchResults.push_back(tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
gemmResult,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
}));
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
loc,
rhsRowType,
rhsSlice,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
}
Value result = batchResults.size() == 1
? batchResults.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult();
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc,
gemmExpandedType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -243,7 +106,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulToGemm>(ctx);
patterns.insert<MatMulRank3ToGemm>(ctx);
}
} // namespace onnx_mlir

View File

@@ -5,7 +5,6 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
Common.cpp
Patterns.cpp
EXCLUDE_FROM_OM_LIBS

View File

@@ -96,8 +96,8 @@ bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value
createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
@@ -127,16 +127,6 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
for (Operation* user : value.getUsers()) {
if (user->getBlock() != operation->getBlock())
return true;
if (operation->isBeforeInBlock(user))
return true;
}
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0);
@@ -144,9 +134,8 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
});
auto validOperands =
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())

View File

@@ -1,287 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
Location loc = extractSliceOp.getLoc();
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
if (spatCompute.getInputs().empty())
return failure();
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
return failure();
}
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
return failure();
}
}
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatToExtract;
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute)) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
{
auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>();
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute)) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
uses.set(mapSpatToExtract[spatCompute]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
}
}
rewriter.eraseOp(extractSliceOp);
return success();
}
};
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
static int i = 0;
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
std::string argName = "const_" + std::to_string(i++);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
constantOp.getValueAttr(),
rewriter.getUnitAttr(),
{});
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute)) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
{
auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>();
if (!spatCompute)
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute)) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
auto parent = constUsers->getParentOfType<spatial::SpatCompute>();
assert(parent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent, newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent]);
}
}
}
auto parent = constantOp->getParentOp();
rewriter.eraseOp(constantOp);
return success();
}
};
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
if (funcOp.getArguments().empty())
return failure();
if (llvm::all_of(funcOp.getArguments(),
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
return failure();
Location loc = funcOp.getLoc();
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
if (arg.getUses().empty())
continue;
rewriter.setInsertionPoint(funcOp.getOperation());
assert(isa<mlir::RankedTensorType>(arg.getType()));
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
mlir::MemRefType memRefType =
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
std::string argName = "arg_" + std::to_string(index);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
rewriter.setInsertionPoint(argUser);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
rewriter.startOpModification(argUser);
argUses.set(getGlobalOp);
rewriter.finalizeOpModification(argUser);
}
}
}
return success();
}
};
} // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}

View File

@@ -1,26 +1,20 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_os_ostream.h"
#include <cassert>
@@ -29,7 +23,6 @@
#include <utility>
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -58,7 +51,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
void runOnOperation() final;
private:
SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
SmallVector<Value> outputTensors;
size_t coreId = 0;
SmallVector<Operation*> operationsToRemove;
@@ -153,21 +146,12 @@ void SpatialToPimPass::runOnOperation() {
scf::SCFDialect,
BuiltinDialect>();
{
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorToMemrefPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp);
@@ -294,7 +278,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
auto outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
Value outputTensor = outputTensors[resultIndexInReturn];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter,
@@ -316,8 +300,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
// Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
@@ -357,8 +341,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
// Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[concatIndexInReturn](rewriter, loc);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
@@ -435,16 +419,8 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
if (outTensorOperand == vmmOp.getInput()) {
rewriter.setInsertionPoint(vmmOp);
auto newOutputBuffer =
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
}
else {
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
}
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
outTensorOperand.setType(newType);
resultTensor.setType(newType);
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
@@ -464,35 +440,17 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back( [returnValue] (IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
outputTensors.push_back(returnValue);
}
else {
auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
mlir::MemRefType memRefType =
mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputName = "output_" + std::to_string(index);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back(
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
auto newOutputTensor =
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType()));
outputTensors.push_back(newOutputTensor);
}
}
}
@@ -500,11 +458,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(valueToReplace.getType());
Type elementType = tensorType.getElementType();
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -513,28 +471,85 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
loc,
tensorType,
deviceTensor,
inputTensor,
hostTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult());
};
// Replace input tensors with memRefs
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
BlockArgument tensorArg = funcOp.getArgument(i);
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front());
auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle");
assert(computeOp.getBody().front().getNumArguments() == 0
&& "Already removed from mergeNode and global input handle");
for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
if (getGlobal.getName().starts_with("arg")) {
assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
insertMemCopyHostToDev(toTensorOpValue, 0);
unsigned numComputeWeights = computeOp.getWeights().size();
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
TypedValue<TensorType> tensorSource;
int64_t elementsOffset = 0;
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
assert("Extracting slice non-contiguous in memory"
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
for (size_t i = 0; i < sliceOffsets.size(); i++) {
int64_t partialOffset = sliceOffsets[i];
if (partialOffset != 0)
for (size_t j = i + 1; j < sourceShape.size(); j++)
partialOffset *= sourceShape[j];
elementsOffset += partialOffset;
}
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
sliceOpsToRemove.insert(sliceOp);
}
else
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
}
}
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
return success();
}
@@ -712,13 +727,12 @@ void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter&
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn](rewriter, loc)); });
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); });
Operation* opToErase = returnOperand;
while (opToErase) {

View File

@@ -24,7 +24,7 @@ def PimTensor :
// Execution
//===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
def PimCoreOp : PimOp<"core", [SingleBlock]> {
let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body);

View File

@@ -178,10 +178,8 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};
@@ -205,10 +203,8 @@ struct MVMOpInterface : DstBufferizableOpInterfaceExternalModel<MVMOpInterface,
if (failed(outputBufferOpt))
return failure();
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimMVMOp>(
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), *inputOpt, *outputBufferOpt);
return success();
}
};

View File

@@ -3,17 +3,12 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
@@ -45,44 +40,14 @@ private:
void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation();
// Refactor this into a function
{
auto funcOp = getPimEntryFunc(moduleOp);
auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
MLIRContext* ctx = moduleOp.getContext();
// failableParallelForEach will run the lambda in parallel and stop if any thread fails
LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
// Again, allocate state LOCALLY per thread/function
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
return failure();
}
return success();
});
if (failed(result)) {
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
signalPassFailure();
}
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
});
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
MLIRContext* ctx = moduleOp.getContext();
@@ -129,8 +94,10 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
funcOp.walk([&](PimCoreOp coreOp) {
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
Value weight = weightUse.get();
auto annotateWeight = [&](unsigned weightIndex) {
if (weightIndex >= coreOp.getWeights().size())
return;
Value weight = coreOp.getWeights()[weightIndex];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return;
@@ -138,7 +105,10 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
assert("Weights must be constants" && globalMemrefOp.getConstant());
markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp);
});
};
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
});
}

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -13,7 +14,10 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
@@ -115,10 +119,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
auto wcomputeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
if (coreOp)
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();

View File

@@ -3,17 +3,15 @@
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <queue>
#include <set>
#include <utility>
#include <vector>
@@ -30,7 +28,7 @@ using namespace mlir;
namespace {
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
llvm::SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
@@ -49,13 +47,11 @@ struct TimingInfo {
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
bool usedAllAvailableCpus = false;
};
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
@@ -63,9 +59,11 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
auto it = edgeWeights.find(key);
if (it == edgeWeights.end())
edgeWeights.insert({key, edgeWeight});
else
it->second = std::max(it->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
@@ -73,15 +71,11 @@ std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(spatComputes.size());
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
@@ -164,27 +158,10 @@ TimingInfo computeTiming(const VirtualGraph& graph) {
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto& neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> selected(timing.aest.size());
std::iota(selected.begin(), selected.end(), 0);
std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
@@ -192,85 +169,21 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const Timing
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
});
selected.resize(std::min(windowSize, selected.size()));
return selected;
}
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
std::vector<size_t> signature;
for (size_t nodeIndex : selectedNodes) {
const VirtualNode& node = graph.nodes[nodeIndex];
signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
}
std::sort(signature.begin(), signature.end());
return signature;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
@@ -284,7 +197,8 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
@@ -304,47 +218,27 @@ WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t>
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
std::sort(mergeGroup.begin(), mergeGroup.end());
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph,
std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
@@ -353,21 +247,18 @@ bool coarsenGraph(const VirtualGraph& graph,
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
@@ -378,7 +269,7 @@ bool coarsenGraph(const VirtualGraph& graph,
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
memberNode.originalComputeIndices.end());
@@ -389,9 +280,8 @@ bool coarsenGraph(const VirtualGraph& graph,
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
@@ -405,65 +295,81 @@ bool coarsenGraph(const VirtualGraph& graph,
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
return computeTiming(coarsenedGraph).valid;
}
constexpr CPU kDefaultMaxCpuCount = 1000;
bool coarsenGraphWithFallback(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
return true;
CPU getVirtualGraphMaxCpuCount() {
if (coresCount.getValue() > 0)
return static_cast<CPU>(coresCount.getValue());
return kDefaultMaxCpuCount;
std::vector<size_t> orderedGroupIndices(mergeGroups.size());
std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
return mergeGroups[lhs].size() > mergeGroups[rhs].size();
});
std::vector<std::vector<size_t>> acceptedMergeGroups;
acceptedMergeGroups.reserve(mergeGroups.size());
for (size_t groupIndex : orderedGroupIndices) {
std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
candidateMergeGroups.push_back(mergeGroups[groupIndex]);
VirtualGraph candidateGraph;
if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
continue;
acceptedMergeGroups = std::move(candidateMergeGroups);
coarsenedGraph = std::move(candidateGraph);
}
return !acceptedMergeGroups.empty();
}
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<SpatCompute> spatComputes) {
DCPAnalysisResult result;
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.resize(computeCount);
graph.edges = aggregateEdges(edges);
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid) {
virtualNodeOrder = std::move(timing.topologicalOrder);
}
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
if (timing.valid)
return timing.topologicalOrder;
std::vector<size_t> originalComputeToCpu(spatComputes.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
std::vector<size_t> fallbackOrder(computeCount);
std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
return fallbackOrder;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalComputeToCpu[originalIndex] = cpu;
}
originalToVirtualNode[originalIndex] = virtualNodeIndex;
result.dominanceOrderCompute.reserve(spatComputes.size());
for (auto [originalIndex, spatCompute] : llvm::enumerate(spatComputes)) {
size_t cpu = originalComputeToCpu[originalIndex];
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
result.dominanceOrderCompute.reserve(dominanceOrder.size());
for (size_t originalIndex : dominanceOrder) {
SpatCompute spatCompute = spatComputes[originalIndex];
size_t cpu = originalToVirtualNode[originalIndex];
result.dominanceOrderCompute.push_back(spatCompute);
result.computeToCpuMap[spatCompute] = cpu;
result.cpuToLastComputeMap[cpu] = spatCompute;
}
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
return result;
}
DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) {
GraphDCP graphDCP(spatComputes, edges);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
@@ -477,12 +383,12 @@ DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<Inde
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = dyn_cast<SpatCompute>(op))
if (auto res = llvm::dyn_cast<SpatCompute>(op))
return res;
return {};
}
@@ -509,74 +415,32 @@ DCPAnalysisResult DCPAnalysis::run() {
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
size_t iteration = 0;
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
std::set<std::vector<size_t>> seenCriticalWindows;
while (virtualGraph.nodes.size() > 1) {
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid)
break;
auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
if (selectedNodes.size() < 2)
break;
if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
break;
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
if (oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
if (windowSchedule.mergeGroups.empty())
break;
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
break;
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
if (windowSchedule.usedAllAvailableCpus)
break;
}
SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, spatComputes);
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
}
} // namespace spatial

View File

@@ -38,14 +38,11 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <queue>
#include <vector>
#include "DCPAnalysis.hpp"
@@ -63,7 +60,6 @@ namespace {
// Coarse-grained phase timers printed when DCP_SELECT_PROFILE is set.
struct SelectTimers {
double findSlot = 0.0;
double dedup = 0.0;
double precheck = 0.0;
double snapshotInsertUpdate = 0.0;
double childSlot = 0.0;
@@ -74,19 +70,9 @@ struct SelectTimers {
long tasksProcessed = 0;
void dump(const char* label) const {
std::fprintf(stderr,
"[selectProfile:%s] tasks=%ld dedup=%.2fs findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs "
"childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
label,
tasksProcessed,
dedup,
findSlot,
precheck,
snapshotInsertUpdate,
childSlot,
rollbackRestore,
iterations,
passedPrecheck,
passedDcpl);
"[selectProfile:%s] tasks=%ld findSlot=%.2fs precheck=%.2fs snapUpd=%.2fs childSlot=%.2fs rollback=%.2fs iter=%ld precheckPass=%ld dcplPass=%ld\n",
label, tasksProcessed, findSlot, precheck, snapshotInsertUpdate, childSlot,
rollbackRestore, iterations, passedPrecheck, passedDcpl);
}
~SelectTimers() {
if (std::getenv("DCP_SELECT_PROFILE"))
@@ -97,101 +83,6 @@ static SelectTimers gSelectTimers;
} // namespace
#endif
namespace {
uint64_t mixHash(uint64_t seed, uint64_t value) {
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
return seed;
}
uint64_t finishHash(uint64_t seed) {
seed ^= seed >> 33;
seed *= 0xff51afd7ed558ccdULL;
seed ^= seed >> 33;
seed *= 0xc4ceb9fe1a85ec53ULL;
seed ^= seed >> 33;
return seed;
}
uint64_t hashEdgeSignature(uint64_t neighborHash, Weight weight, uint64_t direction) {
uint64_t hash = mixHash(0x84222325cbf29ce4ULL, direction);
hash = mixHash(hash, neighborHash);
hash = mixHash(hash, static_cast<uint64_t>(weight));
return finishHash(hash);
}
struct CpuAestCache {
Time defaultAest = 0;
llvm::SmallDenseMap<CPU, Time, 8> colocatedParentAests;
Time get(CPU cpu) const {
auto it = colocatedParentAests.find(cpu);
if (it == colocatedParentAests.end())
return defaultAest;
return it->second;
}
};
struct CpuTimeMax {
CPU cpu = -1;
Time time = 0;
};
void updateCpuTimeMax(CpuTimeMax& first, CpuTimeMax& second, CPU cpu, Time time) {
if (first.cpu == cpu) {
first.time = std::max(first.time, time);
return;
}
if (second.cpu == cpu) {
second.time = std::max(second.time, time);
if (second.time > first.time)
std::swap(first, second);
return;
}
if (time >= first.time) {
second = first;
first = {cpu, time};
return;
}
if (time > second.time)
second = {cpu, time};
}
CpuAestCache computeCpuAestCache(TaskDCP* task) {
CpuAestCache cache;
llvm::SmallDenseMap<CPU, Time, 8> transferAestByCpu;
llvm::SmallDenseMap<CPU, Time, 8> localAestByCpu;
Time unscheduledTransferAest = 0;
for (const Edge& parentEdge : task->parents) {
Time parentFinish = addOrMax(parentEdge.first->getAest(), parentEdge.first->getWeight());
Time transferAest = addOrMax(parentFinish, getTransferCost(parentEdge.first, task));
if (std::optional<CPU> parentCpu = parentEdge.first->getCpu()) {
Time& cpuTransferAest = transferAestByCpu[*parentCpu];
cpuTransferAest = std::max(cpuTransferAest, transferAest);
Time& cpuLocalAest = localAestByCpu[*parentCpu];
cpuLocalAest = std::max(cpuLocalAest, parentFinish);
continue;
}
unscheduledTransferAest = std::max(unscheduledTransferAest, transferAest);
}
CpuTimeMax firstOther {-1, unscheduledTransferAest};
CpuTimeMax secondOther {-1, 0};
for (const auto& entry : transferAestByCpu)
updateCpuTimeMax(firstOther, secondOther, entry.first, entry.second);
cache.defaultAest = firstOther.time;
for (const auto& entry : localAestByCpu) {
CPU cpu = entry.first;
Time bestNonLocalParentAest = firstOther.cpu == cpu ? secondOther.time : firstOther.time;
cache.colocatedParentAests[cpu] = std::max(bestNonLocalParentAest, entry.second);
}
return cache;
}
} // namespace
//===----------------------------------------------------------------------===//
// Edge manipulation
//===----------------------------------------------------------------------===//
@@ -265,49 +156,6 @@ std::vector<TaskDCP*> GraphDCP::getRoots() {
return tmp;
}
void GraphDCP::initTaskStructureHashes() {
taskStructureHashes.resize(nodes.size());
for (auto [index, task] : llvm::enumerate(nodes)) {
uint64_t hash = mixHash(0x7442b1129fd01363ULL, static_cast<uint64_t>(task.getWeight()));
hash = mixHash(hash, static_cast<uint64_t>(task.getCrossbarUsage()));
taskStructureHashes[index] = finishHash(hash);
}
std::vector<uint64_t> nextHashes(nodes.size());
std::vector<uint64_t> edgeHashes;
for (int iteration = 0; iteration < 4; ++iteration) {
for (auto [index, task] : llvm::enumerate(nodes)) {
uint64_t hash = mixHash(0x464dcab27ac82291ULL, taskStructureHashes[index]);
edgeHashes.clear();
edgeHashes.reserve(task.parents.size() + task.children.size());
for (const Edge& parent : task.parents)
if (!parent.isScheduling)
edgeHashes.push_back(
hashEdgeSignature(taskStructureHashes[getNodeIndex(parent.first)], parent.second, /*direction=*/0));
for (const Edge& child : task.children)
if (!child.isScheduling)
edgeHashes.push_back(
hashEdgeSignature(taskStructureHashes[getNodeIndex(child.first)], child.second, /*direction=*/1));
llvm::sort(edgeHashes);
hash = mixHash(hash, static_cast<uint64_t>(edgeHashes.size()));
for (uint64_t edgeHash : edgeHashes)
hash = mixHash(hash, edgeHash);
nextHashes[index] = finishHash(hash);
}
taskStructureHashes.swap(nextHashes);
}
}
// Compact dedup key for CPU `c` vs `candidate`: mixes candidateAest, crossbar
// usage, and the incremental cpu structure hash. No heap allocation.
uint64_t GraphDCP::computeCpuCandidateKey(Time candidateAest, CPU cpu) {
uint64_t hash = mixHash(0xd6e8feb86659fd93ULL, static_cast<uint64_t>(candidateAest));
hash = mixHash(hash, static_cast<uint64_t>(getCpuCrossbarUsage(cpu)));
auto it = cpuStructureHashes.find(cpu);
hash = mixHash(hash, it != cpuStructureHashes.end() ? it->second : 0ULL);
return finishHash(hash);
}
// Inserts `task` at `position` on `cpu`, wiring up scheduling edges with the
// neighbouring tasks and keeping the global topological order consistent.
TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position) {
@@ -316,7 +164,6 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
task->setCpu(cpu);
task->setWeight(scheduledWeight);
reserveTaskCrossbars(cpu, task);
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
auto& tasksInCpu = getOrCreateCpuTasks(cpu);
unsigned int numCpuTasks = tasksInCpu.size();
assert(position <= numCpuTasks && "Inserting in a not valid position");
@@ -354,7 +201,6 @@ TaskInsertion GraphDCP::insertTaskInCPU(CPU cpu, TaskDCP* task, size_t position)
void GraphDCP::removeTaskFromCPU(CPU cpu, TaskDCP* task) {
releaseTaskCrossbars(cpu, task);
cpuStructureHashes[cpu] ^= taskStructureHashes[getNodeIndex(task)];
task->resetCpu();
task->resetWeight();
auto& scheduledTasks = getOrCreateCpuTasks(cpu);
@@ -425,21 +271,6 @@ bool GraphDCP::wouldExhaustCrossbarCapacity(CPU cpu, const TaskDCP* task) const
return nextUsage >= getCpuCrossbarCapacity();
}
size_t GraphDCP::crossbarsUsed() const {
CrossbarUsage crossbarEdge = static_cast<CrossbarUsage>(onnx_mlir::crossbarSize.getValue());
CrossbarUsage crossbarArea = crossbarEdge * crossbarEdge;
if (crossbarArea == 0)
return 0;
CrossbarUsage totalArea = 0;
for (const auto& [cpu, usage] : cpuCrossbarUsage)
totalArea = checkedAdd(totalArea, usage);
return static_cast<size_t>(totalArea / crossbarArea);
}
size_t GraphDCP::crossbarsAvailable() const {
return static_cast<size_t>(lastCpu) * onnx_mlir::crossbarCountInCore.getValue();
}
//===----------------------------------------------------------------------===//
// AEST / ALST computation
//===----------------------------------------------------------------------===//
@@ -625,9 +456,9 @@ void GraphDCP::updateAestFromTaskWithDescendants(TaskDCP* task, llvm::ArrayRef<T
for (TaskDCP* descendant : descendantsTopoOrder)
recomputeAest(descendant);
const bool oldMaxInvalidated =
maxCompletionTask != nullptr
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
const bool oldMaxInvalidated = maxCompletionTask != nullptr
&& (maxCompletionTask == task
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
if (oldMaxInvalidated) {
// The pre-update max came from a modified task; its completion has moved
// upward, so modifiedMaxCompletion is an upper bound covering it. The
@@ -692,9 +523,9 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
if (!process(descendant))
return false;
const bool oldMaxInvalidated =
maxCompletionTask != nullptr
&& (maxCompletionTask == task || llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
const bool oldMaxInvalidated = maxCompletionTask != nullptr
&& (maxCompletionTask == task
|| llvm::is_contained(descendantsTopoOrder, maxCompletionTask));
if (oldMaxInvalidated) {
dcpl = modifiedMaxCompletion;
maxCompletion = modifiedMaxCompletion;
@@ -715,109 +546,6 @@ bool GraphDCP::tryUpdateAestWithinBudget(TaskDCP* task,
return true;
}
// Incrementally refreshes ALST after `task` was placed. The set of nodes whose
// ALST is structurally affected by the insertion is exactly
// `relations.ancestors {task}`: the task's outgoing transfer costs to
// same-CPU real children become 0, and new scheduling edges create parent
// relationships between `task` and its same-CPU neighbors. Every other node
// keeps its relative distance to the sink boundary and only absorbs the
// signed DCPL delta captured between `oldDcpl` and the now-updated `dcpl`.
void GraphDCP::updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl) {
Time newDcpl = getDcpl();
// If the AEST update saturated dcpl (e.g. rescue placement on a
// crossbar-exhausted CPU sets task weight to UINT64_MAX), the shift delta
// would be meaningless. Fall back to a full recompute for this step only.
if (newDcpl == std::numeric_limits<Time>::max()) {
initAlst();
return;
}
if (newDcpl != oldDcpl) {
const bool increased = newDcpl > oldDcpl;
const Time delta = increased ? (newDcpl - oldDcpl) : (oldDcpl - newDcpl);
for (TaskDCP& node : topologicalOrder) {
if (&node == task || relations.ancestors.contains(&node))
continue;
Time alst = node.getAlst();
node.setAlst(increased ? addOrMax(alst, delta) : subtractOrZero(alst, delta));
}
}
auto recomputeAlst = [&](TaskDCP* node) {
Time minAlst = std::numeric_limits<Time>::max();
if (!node->hasChildren())
minAlst = subtractOrZero(newDcpl, node->getWeight());
for (const Edge& childEdge : node->children)
minAlst = std::min(minAlst,
subtractOrZero(childEdge.first->getAlst(),
addOrMax(node->getWeight(), getTransferCost(node, childEdge.first))));
node->setAlst(minAlst);
};
// Walk the backward cone with a pending-children counter so that every
// ancestor is recomputed only after all of its affected children have
// been refreshed. This is resilient to staleness in the global
// `topologicalOrder` relative to freshly added scheduling edges.
llvm::DenseSet<TaskDCP*> affected = relations.ancestors;
affected.insert(task);
llvm::DenseMap<TaskDCP*, int> pendingAffectedChildren;
pendingAffectedChildren.reserve(affected.size());
std::vector<TaskDCP*> worklist;
worklist.reserve(affected.size());
for (TaskDCP* node : affected) {
int count = 0;
for (const Edge& childEdge : node->children)
if (affected.contains(childEdge.first))
count++;
pendingAffectedChildren[node] = count;
if (count == 0)
worklist.push_back(node);
}
while (!worklist.empty()) {
TaskDCP* node = worklist.back();
worklist.pop_back();
recomputeAlst(node);
for (const Edge& parentEdge : node->parents) {
if (!affected.contains(parentEdge.first))
continue;
auto it = pendingAffectedChildren.find(parentEdge.first);
assert(it != pendingAffectedChildren.end());
if (--it->second == 0)
worklist.push_back(parentEdge.first);
}
}
// Opt-in consistency check: verifies the incremental ALST result against a
// full initAlst() recomputation. Very expensive (O(V+E) per placement) - only
// enable when investigating suspected drift.
#ifdef DCP_DEBUG_CHECK_ALST
std::vector<Time> afterIncremental(nodes.size());
for (size_t i = 0; i < nodes.size(); ++i)
afterIncremental[i] = nodes[i].getAlst();
initAlst();
bool mismatched = false;
for (size_t i = 0; i < nodes.size(); ++i) {
if (afterIncremental[i] != nodes[i].getAlst()) {
if (!mismatched) {
llvm::errs() << "[alst-mismatch] placed=" << getNodeIndex(task) << " oldDcpl=" << oldDcpl
<< " newDcpl=" << newDcpl << " ancestors={";
for (TaskDCP* a : relations.ancestors)
llvm::errs() << getNodeIndex(a) << ",";
llvm::errs() << "}\n";
mismatched = true;
}
llvm::errs() << " node=" << i << " incremental=" << afterIncremental[i] << " full=" << nodes[i].getAlst()
<< " weight=" << nodes[i].getWeight()
<< " cpu=" << (nodes[i].isScheduled() ? (int) *nodes[i].getCpu() : -1) << " children=[";
for (const Edge& e : nodes[i].children)
llvm::errs() << getNodeIndex(e.first) << (e.isScheduling ? "s" : "")
<< "(tc=" << getTransferCost(&nodes[i], e.first) << ",alst=" << e.first->getAlst() << "),";
llvm::errs() << "]\n";
}
}
#endif
}
// Computes a localised ALST: only ancestors of the candidate (plus the
// candidate itself) get recomputed, every other task keeps its current ALST.
// Processes nodes in reverse dependency order using a pending-children
@@ -1177,6 +905,32 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
// Candidate selection and processor assignment
//===----------------------------------------------------------------------===//
// Lowest slack wins; earliest AEST breaks ties. Critical-path tasks (zero
// slack) naturally float to the front.
TaskDCP* GraphDCP::findCandidate(const std::vector<TaskDCP*>& readyNodes) {
auto findBestNode = [](auto lft, auto rgt) {
Time leftSlack = slackOrZero((*lft)->getAest(), (*lft)->getAlst());
Time rightSlack = slackOrZero((*rgt)->getAest(), (*rgt)->getAlst());
if (leftSlack < rightSlack)
return lft;
if (rightSlack < leftSlack)
return rgt;
if ((*lft)->getAest() < (*rgt)->getAest())
return lft;
return rgt;
};
assert(!readyNodes.empty() && "expected at least one ready node");
auto validNode = readyNodes.begin();
auto bestNode = validNode;
while (validNode != readyNodes.end()) {
bestNode = findBestNode(validNode, bestNode);
std::advance(validNode, 1);
}
return *bestNode;
}
// Picks the best CPU + slot for `candidate`:
// * Phase 1 (parallel, read-only): call findSlot on every candidate CPU.
// * Phase 2 (sequential): process CPUs in ascending slot.aest order. For
@@ -1185,7 +939,7 @@ GraphDCP::FindSlot GraphDCP::findSlotWithFixedFinalTime(
// evaluate a slot for the smallest-slack child, then roll back.
// * Rescue (sequential): if nothing fit, grow the CPU count if allowed,
// otherwise pick the CPU that leads to the smallest DCPL increase.
GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
void GraphDCP::selectProcessor(TaskDCP* candidate, bool push) {
CandidateRelations relations = dcp_graph::computeCandidateRelations(candidate);
relations.descendantsTopoOrder.reserve(relations.descendants.size());
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
@@ -1205,43 +959,22 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
const CrossbarUsage candidateFootprint = getTaskCrossbarFootprint(candidate);
const bool candidateHasCrossbar = candidateFootprint != 0;
const CrossbarUsage cpuCapacity = candidateHasCrossbar ? getCpuCrossbarCapacity() : 0;
DCP_DEBUG_IF(auto dedupStart = std::chrono::steady_clock::now();)
CpuAestCache cpuAests = computeCpuAestCache(candidate);
DCP_DEBUG_IF(const bool checkCpuAestCache = std::getenv("DCP_CHECK_CPU_AEST_CACHE") != nullptr;)
llvm::SmallDenseSet<uint64_t, 32> seenProcessorKeys;
seenProcessorKeys.reserve(static_cast<size_t>(topCpu + 1));
for (CPU c = 0; c <= topCpu; c++) {
if (candidateHasCrossbar && c != getLastCpu()) {
CrossbarUsage nextUsage = checkedAdd(getCpuCrossbarUsage(c), candidateFootprint);
if (nextUsage >= cpuCapacity)
continue;
}
Time candidateAest = cpuAests.get(c);
DCP_DEBUG_IF(if (checkCpuAestCache) {
Time recomputedAest = computeAestOnCpu(candidate, c);
if (candidateAest != recomputedAest) {
std::fprintf(stderr,
"[DCP_CHECK_CPU_AEST_CACHE] mismatch candidate=%zu cpu=%d cached=%llu recomputed=%llu\n",
getNodeIndex(candidate),
c,
static_cast<unsigned long long>(candidateAest),
static_cast<unsigned long long>(recomputedAest));
llvm::report_fatal_error("DCP CPU AEST cache mismatch");
}
})
if (!seenProcessorKeys.insert(computeCpuCandidateKey(candidateAest, c)).second)
continue;
processors.push_back(c);
}
DCP_DEBUG_IF(gSelectTimers.dedup +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - dedupStart).count();)
if (processors.empty()) {
// processors.empty() implies !canCreateNewCpu: a fresh CPU always passes
// the crossbar filter and would have been added. Reaching here means every
// existing CPU is crossbar-exhausted and the task requires crossbar
// capacity — the placement is impossible.
llvm::report_fatal_error("DCP scheduler: crossbar capacity exhausted on all CPUs; "
"cannot schedule task that requires crossbar allocation");
CPU bestCpu = canCreateNewCpu ? getLastCpu() : 0;
FindSlot bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
if (canCreateNewCpu)
incrementLastCpu();
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
return;
}
// Phase 1: parallel findSlot sweep (read-only over graph state).
@@ -1267,20 +1000,21 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
for (size_t i = 0; i < processors.size(); ++i)
sweep(i);
DCP_DEBUG_IF(gSelectTimers.findSlot +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
std::chrono::duration<double>(std::chrono::steady_clock::now() - sweepStart).count();)
#ifdef DCP_DEBUG_ENABLED
{
static bool reported = false;
if (!reported) {
reported = true;
std::fprintf(
stderr,
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
(void*) context,
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
processors.size(),
context != nullptr && context->isMultithreadingEnabled() ? context->getThreadPool().getMaxConcurrency() : 0u);
std::fprintf(stderr,
"[dcp] selectProcessor parallel sweep: context=%p mt=%d procs=%zu pool=%u\n",
(void*) context,
context != nullptr ? (int) context->isMultithreadingEnabled() : -1,
processors.size(),
context != nullptr && context->isMultithreadingEnabled()
? context->getThreadPool().getMaxConcurrency()
: 0u);
}
}
#endif
@@ -1321,10 +1055,9 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
DCP_DEBUG_IF(auto t2 = std::chrono::steady_clock::now();)
Weight candidateWeight = candidate->computeWeightOnCpu(this, currentCpu);
Time candidateCompletion = addOrMax(slot.aest, candidateWeight);
bool skip =
(!emptyCpu && candidateCompletion > currentDcpl) || addOrMax(slot.aest, candidateCompletion) >= bestComposite;
DCP_DEBUG_IF(gSelectTimers.precheck +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
bool skip = (!emptyCpu && candidateCompletion > currentDcpl)
|| addOrMax(slot.aest, candidateCompletion) >= bestComposite;
DCP_DEBUG_IF(gSelectTimers.precheck += std::chrono::duration<double>(std::chrono::steady_clock::now() - t2).count();)
if (skip)
continue;
DCP_DEBUG_IF(++gSelectTimers.passedPrecheck;)
@@ -1340,8 +1073,8 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
scheduleSnapshot = dcp_graph::captureLocalScheduleState(
candidate, relations.descendants, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
taskInsertion = insertTaskInCPU(currentCpu, candidate, slot.index);
bool withinBudget =
tryUpdateAestWithinBudget(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
bool withinBudget = tryUpdateAestWithinBudget(
candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder), currentDcpl);
if (!withinBudget) {
DCP_DEBUG_IF(auto t4 = std::chrono::steady_clock::now();)
taskInsertion.rollBack();
@@ -1354,7 +1087,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
}
}
DCP_DEBUG_IF(gSelectTimers.snapshotInsertUpdate +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
std::chrono::duration<double>(std::chrono::steady_clock::now() - t3).count();)
DCP_DEBUG_IF(++gSelectTimers.passedDcpl;)
// Pick the tightest unscheduled child (smallest slack) and measure what
@@ -1402,7 +1135,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
dcp_graph::restoreLocalScheduleState(
scheduleSnapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
DCP_DEBUG_IF(gSelectTimers.rollbackRestore +=
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
std::chrono::duration<double>(std::chrono::steady_clock::now() - t6).count();)
}
}
@@ -1417,9 +1150,7 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
else {
Time bestDcpl = std::numeric_limits<Time>::max();
Time currentDcpl = getDcpl();
for (CPU c : processors) {
if (c == getLastCpu())
continue;
for (CPU c = 0; c < getLastCpu(); c++) {
auto slot = findSlot(candidate, c, false, relations);
if (slot.aest == std::numeric_limits<Time>::max())
slot = findSlot(candidate, c, true, relations);
@@ -1428,7 +1159,8 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
// Cheap lower bound: post-insertion DCPL is at least max(currentDcpl,
// candidate completion on this slot). Skip CPUs already worse than
// the best seen.
Time lowerBound = std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
Time lowerBound =
std::max(currentDcpl, addOrMax(slot.aest, candidate->computeWeightOnCpu(this, c)));
if (lowerBound >= bestDcpl)
continue;
auto snapshot = dcp_graph::captureLocalScheduleState(
@@ -1437,37 +1169,23 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(relations.descendantsTopoOrder));
Time candidateDcpl = getDcpl();
taskInsertion.rollBack();
dcp_graph::restoreLocalScheduleState(snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
dcp_graph::restoreLocalScheduleState(
snapshot, dcpl, maxCompletion, secondMaxCompletion, maxCompletionTask);
if (candidateDcpl < bestDcpl) {
bestDcpl = candidateDcpl;
bestCpu = c;
bestSlot = slot;
}
}
if (bestCpu == -1)
llvm::report_fatal_error("DCP scheduler: no valid slot found for task on any eligible CPU — "
"all slots are blocked by already-placed descendants");
if (bestCpu == -1) {
bestCpu = 0;
bestSlot = {computeAestOnCpu(candidate, bestCpu), static_cast<int>(getOrCreateCpuTasks(bestCpu).size())};
}
}
}
if (bestCpu == getLastCpu() && getLastCpu() < maxCpuCount)
incrementLastCpu();
insertTaskInCPU(bestCpu, candidate, bestSlot.index);
// Incremental AEST/ALST refresh replacing the full initAest/initAlst that
// used to run after every placement. Post-insertion relations pick up any
// new scheduling-edge ancestors/descendants introduced by the insertion.
Time oldDcpl = getDcpl();
CandidateRelations postRelations = dcp_graph::computeCandidateRelations(candidate);
llvm::SmallVector<TaskDCP*, 32> postDescendantsTopoOrder;
postDescendantsTopoOrder.reserve(postRelations.descendants.size());
for (auto it = candidate->getTopologicalIterator(); it != topologicalOrder.end(); ++it) {
TaskDCP* current = &*it;
if (current != candidate && postRelations.descendants.contains(current))
postDescendantsTopoOrder.push_back(current);
}
updateAestFromTaskWithDescendants(candidate, llvm::ArrayRef<TaskDCP*>(postDescendantsTopoOrder));
updateAlstFromScheduledTask(candidate, postRelations, oldDcpl);
return postRelations;
}
//===----------------------------------------------------------------------===//
@@ -1476,99 +1194,61 @@ GraphDCP::CandidateRelations GraphDCP::selectProcessor(TaskDCP* candidate, bool
void GraphDCP::runDcp() {
initTopological();
initTaskStructureHashes();
initAest();
initAlst();
dumpDot();
dcp_graph::DcpProgressLogger progressLogger(nodes.size());
llvm::DenseMap<TaskDCP*, int> unscheduledParents;
// Min-heap over ready tasks: tightest slack first, earliest AEST as tiebreak.
// Lazy deletion: when AEST/ALST change after a placement, fresh entries are
// pushed for the affected tasks. Stale ones are detected on pop by comparing
// stored vs current (slack, aest) and re-pushed with the current values.
struct ReadyEntry {
Time slack;
Time aest;
TaskDCP* task;
bool operator>(const ReadyEntry& other) const {
if (slack != other.slack)
return slack > other.slack;
return aest > other.aest;
}
};
std::priority_queue<ReadyEntry, std::vector<ReadyEntry>, std::greater<ReadyEntry>> readyQueue;
size_t readyCount = 0;
auto pushReady = [&](TaskDCP* node) {
readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node});
};
std::vector<TaskDCP*> readyNodes;
readyNodes.reserve(nodes.size());
for (auto& node : nodes) {
int dependencyParents = dcp_graph::countDependencyParents(&node);
unscheduledParents[&node] = dependencyParents;
if (dependencyParents == 0) {
pushReady(&node);
++readyCount;
}
if (dependencyParents == 0)
readyNodes.push_back(&node);
}
size_t xbarsCapacity = static_cast<size_t>(maxCpuCount) * onnx_mlir::crossbarCountInCore.getValue();
progressLogger.printStart(readyCount, maxCpuCount, xbarsCapacity);
progressLogger.printStart(readyNodes.size());
while (readyCount > 0) {
// Pop with lazy deletion: skip stale entries and re-push with current values.
TaskDCP* candidate = nullptr;
while (!readyQueue.empty()) {
auto entry = readyQueue.top();
readyQueue.pop();
Time curSlack = slackOrZero(entry.task->getAest(), entry.task->getAlst());
Time curAest = entry.task->getAest();
if (entry.slack == curSlack && entry.aest == curAest) {
candidate = entry.task;
break;
}
readyQueue.push({curSlack, curAest, entry.task});
}
assert(candidate != nullptr && "readyCount > 0 but heap exhausted");
--readyCount;
while (!readyNodes.empty()) {
DCP_DEBUG_IF(auto findStart = std::chrono::steady_clock::now();)
TaskDCP* candidate = findCandidate(readyNodes);
DCP_DEBUG_IF(progressLogger.recordFindDuration(
std::chrono::duration<double>(std::chrono::steady_clock::now() - findStart).count());)
fastRemove(readyNodes, candidate);
DCP_DEBUG_IF(auto selectStart = std::chrono::steady_clock::now();)
CandidateRelations postRelations = selectProcessor(candidate, candidate->isCriticalPath());
selectProcessor(candidate, candidate->isCriticalPath());
DCP_DEBUG_IF(
double selectSeconds = std::chrono::duration<double>(std::chrono::steady_clock::now() - selectStart).count();
progressLogger.recordSelectDuration(selectSeconds);
progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyCount, getLastCpu());)
// Proactively refresh the heap priority for ready nodes whose AEST or ALST
// changed: ancestors had their ALST individually recomputed; descendants had
// their AEST bumped. Both may now sort differently than their stale entries.
for (TaskDCP* node : postRelations.ancestors)
if (!node->isScheduled() && unscheduledParents[node] == 0)
pushReady(node);
for (TaskDCP* node : postRelations.descendants)
if (!node->isScheduled() && unscheduledParents[node] == 0)
pushReady(node);
progressLogger.maybePrintSlowCandidate(getNodeIndex(candidate), selectSeconds, readyNodes.size(), getLastCpu());
)
DCP_DEBUG_IF(auto updateStart = std::chrono::steady_clock::now();)
initAest();
initAlst();
DCP_DEBUG_IF(progressLogger.recordUpdateDuration(
std::chrono::duration<double>(std::chrono::steady_clock::now() - updateStart).count());)
progressLogger.advanceCompleted();
progressLogger.printProgress(readyCount, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), false);
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "recompute", false);
for (const auto& childEdge : candidate->children) {
if (childEdge.isScheduling || childEdge.first->isScheduled())
continue;
int& dependencyParents = unscheduledParents[childEdge.first];
assert(dependencyParents > 0 && "dependency parent count must stay positive");
--dependencyParents;
if (dependencyParents == 0) {
pushReady(childEdge.first);
++readyCount;
}
dependencyParents--;
if (dependencyParents == 0)
readyNodes.push_back(childEdge.first);
}
DCP_DEBUG_IF(++gSelectTimers.tasksProcessed;
if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
gSelectTimers.dump("tick");)
DCP_DEBUG_IF(
++gSelectTimers.tasksProcessed;
if (std::getenv("DCP_SELECT_PROFILE") && (gSelectTimers.tasksProcessed % 100 == 0))
gSelectTimers.dump("tick");
)
}
progressLogger.printProgress(0, getLastCpu(), maxCpuCount, crossbarsUsed(), crossbarsAvailable(), true);
progressLogger.printProgress(readyNodes.size(), getLastCpu(), "done", true);
dumpDot();
}

View File

@@ -4,7 +4,6 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <list>
#include <optional>
#include <unordered_map>
@@ -49,10 +48,8 @@ private:
std::vector<TaskDCP> nodes;
onnx_mlir::LabeledList<TaskDCP> topologicalOrder;
std::vector<uint64_t> taskStructureHashes;
std::vector<CpuTaskList> cpuTasks;
std::unordered_map<CPU, CrossbarUsage> cpuCrossbarUsage;
llvm::DenseMap<CPU, uint64_t> cpuStructureHashes;
CPU lastCpu = 0;
long long flag = 1;
Time dcpl = 0;
@@ -73,7 +70,6 @@ private:
void initAest();
void initAlst();
void initTaskStructureHashes();
Time computeAestOnCpu(TaskDCP* task, CPU cpu);
Time computeDcplOnCpu(TaskDCP* task, CPU cpu);
@@ -87,15 +83,9 @@ private:
// `dcplBudget`, signalling that the new DCPL would exceed the budget.
// Returns true iff the full propagation completed without exceeding the
// budget. Uses the caller's snapshot to restore AEST on the aborted tail.
bool tryUpdateAestWithinBudget(TaskDCP* task, llvm::ArrayRef<TaskDCP*> descendantsTopoOrder, Time dcplBudget);
// Incrementally refreshes ALST after `task` has been scheduled. Nodes
// outside the backward cone (`relations.ancestors` plus `task`) retain
// their relative distance to the sink boundary and only absorb the signed
// DCPL delta (`newDcpl - oldDcpl`). `task` itself and its ancestors are
// recomputed in reverse topological order so that new same-CPU transfer
// costs (now zero) and scheduling-edge children are reflected.
void updateAlstFromScheduledTask(TaskDCP* task, const CandidateRelations& relations, Time oldDcpl);
bool tryUpdateAestWithinBudget(TaskDCP* task,
llvm::ArrayRef<TaskDCP*> descendantsTopoOrder,
Time dcplBudget);
void initTopological();
void topologicalMoveAfter(TaskDCP* task, TaskDCP* pivotPoint, TaskInsertion* insertion = nullptr);
@@ -104,11 +94,8 @@ private:
llvm::DenseMap<TaskDCP*, Time> computeAlst(TaskDCP* task, CPU cpu, const CandidateRelations& relations);
size_t getNodeIndex(const TaskDCP* task) const;
// Returns a compact dedup key for CPU `c` when evaluating `candidate`:
// mixes candidateAest, crossbar usage, and the incremental cpu structure
// hash into a single uint64_t. Zero heap allocation.
uint64_t computeCpuCandidateKey(Time candidateAest, CPU cpu);
CandidateRelations selectProcessor(TaskDCP* candidate, bool push);
TaskDCP* findCandidate(const std::vector<TaskDCP*>& readyNodes);
void selectProcessor(TaskDCP* candidate, bool push);
CPU getLastCpu() const { return lastCpu; }
void incrementLastCpu() { lastCpu++; }
FindSlot findSlot(TaskDCP* candidate, CPU cpu, bool push, const CandidateRelations& relations);
@@ -128,7 +115,8 @@ private:
public:
void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes, llvm::ArrayRef<IndexedEdge> edges)
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatCompute : spatComputes)
nodes.emplace_back(spatCompute);
@@ -162,11 +150,6 @@ public:
void setMaxCpuCount(int value) { maxCpuCount = value; }
int getMaxCpuCount() const { return maxCpuCount; }
// Total crossbar units allocated across all active CPUs.
size_t crossbarsUsed() const;
// Maximum crossbar units available across all active CPUs (lastCpu * per-CPU capacity).
size_t crossbarsAvailable() const;
// Optional MLIR context used to drive mlir::parallelFor inside runDcp. If
// null the scheduler runs single-threaded (tests use this path).
void setContext(mlir::MLIRContext* ctx) { context = ctx; }

View File

@@ -35,12 +35,10 @@ void DcpProgressLogger::recordSelectDuration(double seconds) { selectProcessorSe
void DcpProgressLogger::recordUpdateDuration(double seconds) { updateTimingSeconds += seconds; }
void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += taskCount; }
void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const {
void DcpProgressLogger::printStart(size_t readyCount) const {
if (!logProgress)
return;
llvm::errs() << llvm::formatv(
"[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n",
totalTasks, readyCount, maxCpuCount, xbarsCapacity);
llvm::errs() << llvm::formatv("[DCP] start: tasks={0} ready={1}\n", totalTasks, readyCount);
}
void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
@@ -50,15 +48,14 @@ void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex,
if (!logProgress || elapsedSeconds < 1.0)
return;
llvm::errs() << llvm::formatv("[DCP] slow node={0} elapsed={1} ready={2} cpus={3}\n",
llvm::errs() << llvm::formatv("[DCP] slow candidate node={0} elapsed={1} ready={2} cpus={3}\n",
nodeIndex,
formatDuration(elapsedSeconds),
readyCount,
cpuCount);
}
void DcpProgressLogger::printProgress(
size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force) {
void DcpProgressLogger::printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force) {
if (!logProgress)
return;
@@ -71,19 +68,19 @@ void DcpProgressLogger::printProgress(
double etaSeconds = rate > 0.0 ? static_cast<double>(totalTasks - completedTasks) / rate : 0.0;
double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast<double>(completedTasks) / totalTasks);
bool done = completedTasks == totalTasks;
llvm::errs() << llvm::formatv(
"[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
maxCpuCount,
xbarsUsed,
xbarsAvailable,
llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(),
done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str());
llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F1}%) ready={3} cpus={4} stage={5} elapsed={6} eta={7}\n",
completedTasks,
totalTasks,
percent,
readyCount,
cpuCount,
stage,
formatDuration(elapsedSeconds),
completedTasks == totalTasks ? "0:00" : formatDuration(etaSeconds));
llvm::errs() << llvm::formatv(" time(find={0}, select={1}, update={2})\n",
formatDuration(findCandidateSeconds),
formatDuration(selectProcessorSeconds),
formatDuration(updateTimingSeconds));
lastProgressPrint = now;
}
@@ -94,9 +91,9 @@ void DcpProgressLogger::recordFindDuration(double) {}
void DcpProgressLogger::recordSelectDuration(double) {}
void DcpProgressLogger::recordUpdateDuration(double) {}
void DcpProgressLogger::advanceCompleted(size_t) {}
void DcpProgressLogger::printStart(size_t, int, size_t) const {}
void DcpProgressLogger::printStart(size_t) const {}
void DcpProgressLogger::maybePrintSlowCandidate(size_t, double, size_t, CPU) const {}
void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {}
void DcpProgressLogger::printProgress(size_t, CPU, llvm::StringRef, bool) {}
#endif

View File

@@ -31,10 +31,9 @@ public:
void recordUpdateDuration(double seconds);
void advanceCompleted(size_t taskCount = 1);
void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const;
void printStart(size_t readyCount) const;
void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const;
void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount,
size_t xbarsUsed, size_t xbarsAvailable, bool force);
void printProgress(size_t readyCount, CPU cpuCount, llvm::StringRef stage, bool force);
#ifdef DCP_DEBUG_ENABLED
private:

View File

@@ -1,5 +1,4 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
@@ -9,32 +8,17 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace mlir;
@@ -42,629 +26,6 @@ namespace onnx_mlir {
namespace {
using SpatCompute = spatial::SpatCompute;
SpatCompute getOriginalSpatCompute(Operation* op) {
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op))
op = extract.getSource().getDefiningOp();
return dyn_cast_if_present<SpatCompute>(op);
}
struct ComputeMotifInfo {
uint64_t instructionCount = 0;
uint64_t weightedMvmCount = 0;
uint64_t weightedVmmCount = 0;
};
void appendUnique(SmallVector<size_t>& values, size_t value) {
if (!llvm::is_contained(values, value))
values.push_back(value);
}
bool isTrivialSerialMergeCandidate(SpatCompute compute) {
if (!compute->hasOneUse())
return false;
auto& use = *compute->getUses().begin();
auto user = dyn_cast<SpatCompute>(use.getOwner());
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
}
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SpatCompute target, ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(target.getWeights()))
targetWeightIndices[weight].push_back(weightIndex);
DenseMap<Value, size_t> usedSourceWeightOccurrences;
SmallVector<size_t> sourceToTargetIndex;
sourceToTargetIndex.reserve(sourceWeights.size());
auto targetWeights = target.getWeightsMutable();
for (Value weight : sourceWeights) {
size_t occurrence = usedSourceWeightOccurrences[weight]++;
auto& matchingIndices = targetWeightIndices[weight];
if (occurrence >= matchingIndices.size()) {
size_t newIndex = target.getWeights().size();
targetWeights.append(weight);
matchingIndices.push_back(newIndex);
sourceToTargetIndex.push_back(newIndex);
continue;
}
sourceToTargetIndex.push_back(matchingIndices[occurrence]);
}
return sourceToTargetIndex;
}
void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(funcOp->getContext());
SmallVector<SpatCompute> trivialComputes;
llvm::SmallSet<SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<SpatCompute>())
if (isTrivialSerialMergeCandidate(compute))
trivialComputes.push_back(compute);
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto& computeUse = *compute->getUses().begin();
auto child = cast<SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
SmallVector<size_t> childWeightToNewIndex = appendMissingWeightsAndBuildIndexMap(newCompute, child.getWeights());
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
mapper.map(weight, *std::next(newCompute.getWeights().begin(), childWeightToNewIndex[oldIndex]));
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
auto remapWeightIndex = [&](auto weightedOp) {
auto oldIndex = weightedOp.getWeightIndex();
assert(static_cast<size_t>(oldIndex) < childWeightToNewIndex.size() && "weight index out of range");
weightedOp.setWeightIndex(childWeightToNewIndex[oldIndex]);
};
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(newInst))
remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(newInst))
remapWeightIndex(weightedVmmOp);
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (isTrivialSerialMergeCandidate(newCompute))
trivialComputes.push_back(newCompute);
}
for (auto compute : toErase) {
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
struct WeightedVmmBandCandidate {
Operation* parent;
SpatCompute compute;
};
bool isSingleWeightedVmmCompute(SpatCompute compute) {
if (compute.getNumResults() != 1 || compute.getWeights().size() != 1 || compute.getInputs().size() != 1)
return false;
uint64_t weightedVmmCount = 0;
for (Operation& op : compute.getBody().front()) {
if (isa<spatial::SpatWeightedMVMOp>(&op))
return false;
if (isa<spatial::SpatWeightedVMMOp>(&op))
weightedVmmCount++;
}
return weightedVmmCount == 1;
}
std::optional<WeightedVmmBandCandidate> getWeightedVmmBandCandidate(SpatCompute compute) {
if (!isSingleWeightedVmmCompute(compute) || !compute->hasOneUse())
return std::nullopt;
auto& use = *compute->getUses().begin();
auto child = dyn_cast<SpatCompute>(use.getOwner());
if (!child || use.getOperandNumber() < child.getWeights().size())
return std::nullopt;
auto parent = getOriginalSpatCompute(compute.getInputs().front().getDefiningOp());
if (!parent || parent == child)
return std::nullopt;
return WeightedVmmBandCandidate {parent.getOperation(), compute};
}
bool haveSameWeightedVmmBandShape(SpatCompute lhs, SpatCompute rhs) {
return lhs.getWeights().front().getType() == rhs.getWeights().front().getType()
&& lhs.getInputs().front().getType() == rhs.getInputs().front().getType()
&& lhs.getResult(0).getType() == rhs.getResult(0).getType();
}
SpatCompute packWeightedVmmComputes(func::FuncOp funcOp, ArrayRef<SpatCompute> computes) {
assert(!computes.empty() && "expected at least one compute to pack");
IRRewriter rewriter(funcOp->getContext());
SpatCompute child = cast<SpatCompute>((*computes.front()->getUses().begin()).getOwner());
rewriter.setInsertionPoint(child);
SmallVector<Value> operands;
SmallVector<Type> inputTypes;
SmallVector<Location> inputLocs;
SmallVector<Type> resultTypes;
operands.reserve(computes.size() * 2);
inputTypes.reserve(computes.size());
inputLocs.reserve(computes.size());
resultTypes.reserve(computes.size());
for (SpatCompute compute : computes)
for (Value weight : compute.getWeights())
operands.push_back(weight);
for (SpatCompute compute : computes) {
for (Value input : compute.getInputs()) {
operands.push_back(input);
inputTypes.push_back(input.getType());
inputLocs.push_back(input.getLoc());
}
for (Type resultType : compute.getResultTypes())
resultTypes.push_back(resultType);
}
auto packedCompute = SpatCompute::create(rewriter, funcOp.getLoc(), resultTypes, operands);
packedCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(computes.size()), static_cast<int>(inputTypes.size())});
auto* block = rewriter.createBlock(&packedCompute.getBody(), packedCompute.getBody().end(), inputTypes, inputLocs);
rewriter.setInsertionPointToEnd(block);
SmallVector<Value> yieldValues;
yieldValues.reserve(resultTypes.size());
size_t inputBaseIndex = 0;
size_t weightBaseIndex = 0;
for (SpatCompute compute : computes) {
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(weight, *std::next(packedCompute.getWeights().begin(), weightBaseIndex + weightIndex));
for (auto [inputIndex, bbArg] : llvm::enumerate(compute.getBody().front().getArguments()))
mapper.map(bbArg, block->getArgument(inputBaseIndex + inputIndex));
auto remapWeightIndex = [&](auto weightedOp) {
weightedOp.setWeightIndex(weightBaseIndex + weightedOp.getWeightIndex());
};
for (Operation& op : compute.getBody().front()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
for (Value yieldOperand : yield.getOperands())
yieldValues.push_back(mapper.lookup(yieldOperand));
continue;
}
Operation* cloned = rewriter.clone(op, mapper);
if (auto weightedMvmOp = dyn_cast<spatial::SpatWeightedMVMOp>(cloned))
remapWeightIndex(weightedMvmOp);
if (auto weightedVmmOp = dyn_cast<spatial::SpatWeightedVMMOp>(cloned))
remapWeightIndex(weightedVmmOp);
}
inputBaseIndex += compute.getInputs().size();
weightBaseIndex += compute.getWeights().size();
}
spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), yieldValues);
size_t resultIndex = 0;
for (SpatCompute compute : computes)
for (OpResult result : compute->getResults())
result.replaceAllUsesWith(packedCompute.getResult(resultIndex++));
for (SpatCompute compute : llvm::reverse(computes))
compute.erase();
return packedCompute;
}
size_t getFastPathCpuBudget() {
constexpr size_t kDefaultMaxCpuCount = 1000;
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return kDefaultMaxCpuCount;
}
size_t packWideWeightedVmmBands(func::FuncOp funcOp) {
constexpr size_t kMinBandSize = 64;
size_t cpuBudget = std::max<size_t>(1, getFastPathCpuBudget());
SmallVector<WeightedVmmBandCandidate> candidates;
for (SpatCompute compute : funcOp.getOps<SpatCompute>())
if (auto candidate = getWeightedVmmBandCandidate(compute))
candidates.push_back(*candidate);
llvm::stable_sort(candidates, [](const WeightedVmmBandCandidate& lhs, const WeightedVmmBandCandidate& rhs) {
if (lhs.parent != rhs.parent)
return lhs.parent < rhs.parent;
return lhs.compute->isBeforeInBlock(rhs.compute);
});
size_t packedBandCount = 0;
size_t packedNodeCount = 0;
size_t createdNodeCount = 0;
size_t minChunkSize = std::numeric_limits<size_t>::max();
size_t maxChunkSize = 0;
SmallVector<SmallVector<SpatCompute>> shapeGroups;
for (size_t parentBegin = 0; parentBegin < candidates.size();) {
size_t parentEnd = parentBegin + 1;
while (parentEnd < candidates.size() && candidates[parentEnd].parent == candidates[parentBegin].parent)
parentEnd++;
shapeGroups.clear();
for (size_t index = parentBegin; index < parentEnd; ++index) {
SpatCompute compute = candidates[index].compute;
auto groupIt = llvm::find_if(shapeGroups, [&](const SmallVector<SpatCompute>& group) {
return haveSameWeightedVmmBandShape(group.front(), compute);
});
if (groupIt == shapeGroups.end())
shapeGroups.push_back({compute});
else
groupIt->push_back(compute);
}
for (ArrayRef<SpatCompute> band : shapeGroups) {
size_t bandSize = band.size();
if (bandSize < kMinBandSize || bandSize <= cpuBudget)
continue;
size_t chunkSize = (bandSize + cpuBudget - 1) / cpuBudget;
packedBandCount++;
SmallVector<SpatCompute> chunk;
chunk.reserve(std::min(chunkSize, bandSize));
for (auto [index, compute] : llvm::enumerate(band)) {
chunk.push_back(compute);
if (chunk.size() == chunkSize || index + 1 == band.size()) {
minChunkSize = std::min(minChunkSize, chunk.size());
maxChunkSize = std::max(maxChunkSize, chunk.size());
packWeightedVmmComputes(funcOp, chunk);
packedNodeCount += chunk.size();
createdNodeCount++;
chunk.clear();
}
}
}
parentBegin = parentEnd;
}
if (packedNodeCount != 0)
llvm::errs() << llvm::formatv("[DCP-FASTPATH] wvmmBands={0} packedNodes={1} createdNodes={2} changed={3} "
"cpuBudget={4} chunkSizeRange={5}-{6}\n",
packedBandCount,
packedNodeCount,
createdNodeCount,
packedNodeCount - createdNodeCount,
cpuBudget,
minChunkSize,
maxChunkSize);
return packedNodeCount - createdNodeCount;
}
void emitMotifProfile(func::FuncOp funcOp) {
if (!std::getenv("DCP_MOTIF_PROFILE"))
return;
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseMap<SpatCompute, size_t> computeToIndex;
computeToIndex.reserve(computes.size());
for (auto [index, compute] : llvm::enumerate(computes))
computeToIndex[compute] = index;
SmallVector<ComputeMotifInfo> computeInfos(computes.size());
SmallVector<SmallVector<size_t>> parents(computes.size());
SmallVector<SmallVector<size_t>> children(computes.size());
uint64_t weightedVmmNodeCount = 0;
uint64_t weightedVmmOpCount = 0;
uint64_t edgeCount = 0;
for (auto [index, compute] : llvm::enumerate(computes)) {
ComputeMotifInfo& info = computeInfos[index];
for (Operation& op : compute.getBody().front()) {
info.instructionCount++;
if (isa<spatial::SpatWeightedMVMOp>(&op))
info.weightedMvmCount++;
if (isa<spatial::SpatWeightedVMMOp>(&op))
info.weightedVmmCount++;
}
if (info.weightedVmmCount > 0) {
weightedVmmNodeCount++;
weightedVmmOpCount += info.weightedVmmCount;
}
for (Value input : compute.getInputs()) {
auto parent = getOriginalSpatCompute(input.getDefiningOp());
if (!parent || parent == compute)
continue;
auto parentIt = computeToIndex.find(parent);
if (parentIt == computeToIndex.end())
continue;
size_t parentIndex = parentIt->second;
size_t oldParentCount = parents[index].size();
appendUnique(parents[index], parentIndex);
if (parents[index].size() != oldParentCount) {
appendUnique(children[parentIndex], index);
edgeCount++;
}
}
}
uint64_t maxFanIn = 0;
uint64_t maxFanOut = 0;
uint64_t fanIn16 = 0;
uint64_t fanIn64 = 0;
uint64_t fanIn256 = 0;
uint64_t fanOut16 = 0;
uint64_t fanOut64 = 0;
uint64_t fanOut256 = 0;
for (size_t index = 0; index < computes.size(); ++index) {
uint64_t fanIn = parents[index].size();
uint64_t fanOut = children[index].size();
maxFanIn = std::max(maxFanIn, fanIn);
maxFanOut = std::max(maxFanOut, fanOut);
fanIn16 += fanIn >= 16;
fanIn64 += fanIn >= 64;
fanIn256 += fanIn >= 256;
fanOut16 += fanOut >= 16;
fanOut64 += fanOut >= 64;
fanOut256 += fanOut >= 256;
}
uint64_t serialChainCount = 0;
uint64_t serialChainNodeCount = 0;
uint64_t maxSerialChain = 0;
for (size_t index = 0; index < computes.size(); ++index) {
if (parents[index].size() == 1 && children[parents[index][0]].size() == 1)
continue;
uint64_t chainLength = 1;
size_t current = index;
while (children[current].size() == 1) {
size_t child = children[current][0];
if (parents[child].size() != 1)
break;
chainLength++;
current = child;
}
if (chainLength >= 2) {
serialChainCount++;
serialChainNodeCount += chainLength;
maxSerialChain = std::max(maxSerialChain, chainLength);
}
}
SmallVector<size_t> incomingEdgeCount;
incomingEdgeCount.reserve(parents.size());
for (ArrayRef<size_t> parentList : parents)
incomingEdgeCount.push_back(parentList.size());
SmallVector<uint64_t> level(computes.size(), 0);
SmallVector<size_t> readyNodes;
readyNodes.reserve(computes.size());
for (size_t index = 0; index < computes.size(); ++index)
if (incomingEdgeCount[index] == 0)
readyNodes.push_back(index);
size_t readyIndex = 0;
while (readyIndex != readyNodes.size()) {
size_t current = readyNodes[readyIndex++];
for (size_t child : children[current]) {
level[child] = std::max(level[child], level[current] + 1);
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push_back(child);
}
}
SmallVector<uint64_t> weightedVmmNodesByLevel;
for (size_t index = 0; index < computes.size(); ++index) {
if (computeInfos[index].weightedVmmCount == 0)
continue;
if (weightedVmmNodesByLevel.size() <= level[index])
weightedVmmNodesByLevel.resize(level[index] + 1, 0);
weightedVmmNodesByLevel[level[index]]++;
}
uint64_t maxWeightedVmmLevel = 0;
uint64_t wideWeightedVmmLevels64 = 0;
uint64_t wideWeightedVmmLevels256 = 0;
for (uint64_t count : weightedVmmNodesByLevel) {
maxWeightedVmmLevel = std::max(maxWeightedVmmLevel, count);
wideWeightedVmmLevels64 += count >= 64;
wideWeightedVmmLevels256 += count >= 256;
}
using ShapeKey = std::tuple<uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t>;
SmallVector<ShapeKey> weightedVmmShapeKeys;
for (auto [index, compute] : llvm::enumerate(computes)) {
const ComputeMotifInfo& info = computeInfos[index];
if (info.weightedVmmCount == 0)
continue;
weightedVmmShapeKeys.push_back({info.instructionCount,
info.weightedVmmCount,
info.weightedMvmCount,
static_cast<uint64_t>(compute.getWeights().size()),
static_cast<uint64_t>(compute.getInputs().size()),
static_cast<uint64_t>(parents[index].size()),
static_cast<uint64_t>(children[index].size())});
}
llvm::sort(weightedVmmShapeKeys);
SmallVector<std::pair<uint64_t, ShapeKey>> weightedVmmShapeCounts;
for (size_t index = 0; index < weightedVmmShapeKeys.size();) {
size_t next = index + 1;
while (next < weightedVmmShapeKeys.size() && weightedVmmShapeKeys[next] == weightedVmmShapeKeys[index])
next++;
weightedVmmShapeCounts.push_back({next - index, weightedVmmShapeKeys[index]});
index = next;
}
llvm::sort(weightedVmmShapeCounts, [](const auto& lhs, const auto& rhs) {
if (lhs.first != rhs.first)
return lhs.first > rhs.first;
return lhs.second < rhs.second;
});
llvm::errs() << llvm::formatv("[DCP-MOTIF] computes={0} edges={1} wvmmNodes={2} wvmmOps={3} "
"serialChains={4} serialChainNodes={5} maxSerialChain={6} "
"maxFanIn={7} maxFanOut={8} fanIn>=16/64/256={9}/{10}/{11} "
"fanOut>=16/64/256={12}/{13}/{14} topoVisited={15}\n",
computes.size(),
edgeCount,
weightedVmmNodeCount,
weightedVmmOpCount,
serialChainCount,
serialChainNodeCount,
maxSerialChain,
maxFanIn,
maxFanOut,
fanIn16,
fanIn64,
fanIn256,
fanOut16,
fanOut64,
fanOut256,
readyNodes.size());
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmLevels={0} maxWvmmLevel={1} wideWvmmLevels>=64/256={2}/{3} "
"shapeGroups={4}\n",
weightedVmmNodesByLevel.size(),
maxWeightedVmmLevel,
wideWeightedVmmLevels64,
wideWeightedVmmLevels256,
weightedVmmShapeCounts.size());
for (size_t rank = 0, end = std::min<size_t>(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) {
auto [count, shape] = weightedVmmShapeCounts[rank];
auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape;
llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} "
"mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n",
rank,
count,
insts,
vmmOps,
mvmOps,
weights,
inputs,
fanIn,
fanOut);
}
}
void generateReport(func::FuncOp funcOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dcp_graph";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".txt", std::ios::out);
llvm::raw_os_ostream os(file);
uint64_t numSpatCompute = 0;
std::vector<std::tuple<uint64_t, uint64_t, uint64_t>> collectedData;
for (auto spatCompute : funcOp.getOps<SpatCompute>()) {
uint64_t numInst = 0;
for (auto& _ : spatCompute.getRegion().front())
numInst++;
collectedData.push_back({numSpatCompute++, spatCompute.getWeights().size(), numInst});
}
std::stable_sort(collectedData.begin(),
collectedData.end(),
[](std::tuple<uint64_t, uint64_t, uint64_t> lft, std::tuple<uint64_t, uint64_t, uint64_t> rgt) {
auto [iLft, weightLft, numInstLft] = lft;
auto [iRgt, weightRgt, numInstRgt] = rgt;
if (numInstLft < numInstRgt)
return false;
else if (numInstRgt < numInstLft)
return true;
if (weightLft < weightRgt)
return false;
else if (weightRgt < weightLft)
return true;
if (iLft < iRgt)
return true;
else if (iRgt < iLft)
return false;
return true;
});
for (uint64_t cI = 0; cI < numSpatCompute; ++cI) {
uint64_t lastIndex = cI;
auto [currentComputeId, currentWeight, currentNumInst] = collectedData[cI];
for (uint64_t nI = cI + 1; nI < numSpatCompute; ++nI) {
auto [nextComputeId, nextWeight, nextNumInst] = collectedData[nI];
if (currentWeight == nextWeight && currentNumInst == nextNumInst)
lastIndex = nI;
else
break;
}
os << "Compute " << currentComputeId;
auto expectedPrintedValue = currentComputeId + 1;
bool rangePrinted = false;
cI++;
for (; cI <= lastIndex; ++cI) {
auto candidateToPrint = std::get<0>(collectedData[cI]);
if (candidateToPrint == expectedPrintedValue) {
expectedPrintedValue = candidateToPrint + 1;
rangePrinted = true;
}
else {
if (rangePrinted)
os << " - " << expectedPrintedValue - 1;
os << " , " << candidateToPrint;
rangePrinted = false;
expectedPrintedValue = candidateToPrint + 1;
}
}
if (rangePrinted && currentComputeId != expectedPrintedValue - 1)
os << " - " << expectedPrintedValue - 1;
os << " :\n";
os << "\tNumber of instructions " << currentNumInst << "\n";
os << "\tNumber of used crossbars " << currentWeight << "\n";
cI = lastIndex;
}
os.flush();
file.close();
}
struct ComputeValueResults {
SmallVector<Value> innerValues;
@@ -684,7 +45,9 @@ public:
LazyInsertComputeResult(ComputeValueResults computeValueResults,
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
bool isOnlyChannel)
: computeResults(computeValueResults), onlyChannel(isOnlyChannel), channelNewInserter(channelNewInserter) {}
: computeResults(computeValueResults),
onlyChannel(isOnlyChannel),
channelNewInserter(channelNewInserter) {}
struct ChannelOrLocalOp {
Value data;
@@ -735,79 +98,45 @@ public:
LogicalResult initialize(MLIRContext* context) override { return success(); }
void verifyOrderAssumption(std::vector<spatial::SpatCompute>& dominanceOrderCompute) {
uint64_t computeNumber = 0;
llvm::DenseSet<SpatCompute> visited;
mlir::func::FuncOp funcOp = getOperation();
for (auto spatCompute : funcOp.getOps<SpatCompute>())
computeNumber++;
assert(computeNumber == dominanceOrderCompute.size());
for(auto domCompute : dominanceOrderCompute){
visited.insert(domCompute);
for(auto domInput : domCompute.getInputs() ){
if(auto domImputAsCompute = dyn_cast_if_present<SpatCompute>(domInput.getDefiningOp())){
assert(visited.contains(domImputAsCompute) && "Dominance order violated\n");
}
}
}
}
void runOnOperation() override {
mergeTriviallyConnectedComputes(getOperation());
packWideWeightedVmmBands(getOperation());
emitMotifProfile(getOperation());
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
func::FuncOp func = getOperation();
verifyOrderAssumption(analysisResult.dominanceOrderCompute);
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
if (!cpuToNewComputeMap.contains(cpu)) {
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
auto [newCompute, computeValueResult] =
createNewComputeNode(currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
auto [newCompute, computeValueResult] = createNewComputeNode(
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
cpuToNewComputeMap[cpu] = newCompute;
newComputeNodeResults.insert(std::make_pair(
currentComputeNode,
createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
}
else {
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
newComputeNodeResults.insert(std::make_pair(
currentComputeNode,
createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
}
}
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
if (!computeNodeToRemove->use_empty()) {
llvm::dbgs() << "Full module\n";
computeNodeToRemove->getParentOfType<ModuleOp>()->dump();
llvm::dbgs() << "Compute with uses:\n";
computeNodeToRemove.dump();
}
for (auto users : computeNodeToRemove->getUsers()) {
llvm::dbgs() << "Users:\n";
for (auto users : computeNodeToRemove->getUsers())
users->dump();
}
computeNodeToRemove.erase();
}
func::FuncOp func = getOperation();
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "spatial1_dcp_merged_report");
}
private:
std::pair<SpatCompute, ComputeValueResults>
createNewComputeNode(SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
func::FuncOp func = getOperation();
auto loc = func.getLoc();
IRRewriter rewriter(&getContext());
@@ -832,7 +161,8 @@ private:
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
newCompute.getProperties().setOperandSegmentSizes(
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
@@ -848,14 +178,15 @@ private:
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
}
else {
auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
auto argWeightCompute =
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
assert(isChannel == true);
spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
}
}
@@ -889,34 +220,22 @@ private:
IRRewriter rewriter(&getContext());
IRMapping mapper;
DenseMap<Value, SmallVector<size_t, 4>> toWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(toCompute.getWeights()))
toWeightIndices[weight].push_back(weightIndex);
DenseMap<Value, size_t> usedFromWeightOccurrences;
SmallVector<size_t> fromWeightToNewIndex;
fromWeightToNewIndex.reserve(fromCompute.getWeights().size());
auto weightMutableIter = toCompute.getWeightsMutable();
for (auto weight : fromCompute.getWeights()) {
size_t occurrence = usedFromWeightOccurrences[weight]++;
auto& matchingIndices = toWeightIndices[weight];
if (occurrence >= matchingIndices.size()) {
auto founded = llvm::find(toCompute.getWeights(), weight);
if (founded == toCompute.getWeights().end()) {
size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size();
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
matchingIndices.push_back(sizeW);
fromWeightToNewIndex.push_back(sizeW);
assert(sizeW + 1 == toCompute.getWeights().size());
assert(sizeI == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
}
else {
size_t newIndex = matchingIndices[occurrence];
mapper.map(weight, *std::next(toCompute.getWeights().begin(), newIndex));
fromWeightToNewIndex.push_back(newIndex);
mapper.map(weight, *founded);
}
}
@@ -968,8 +287,9 @@ private:
ComputeValueResults computeValueResults;
auto remapWeightIndex = [&](auto weightedOp) {
auto oldIndex = weightedOp.getWeightIndex();
assert(static_cast<size_t>(oldIndex) < fromWeightToNewIndex.size() && "weight index out of range");
weightedOp.setWeightIndex(fromWeightToNewIndex[oldIndex]);
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
weightedOp.setWeightIndex(newIndex);
};
for (auto& op : fromCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
@@ -998,8 +318,9 @@ private:
return {cast<SpatCompute>(toCompute), computeValueResults};
}
LazyInsertComputeResult
createLazyComputeResult(SpatCompute compute, ComputeValueResults computeValueResults, bool lastCompute) {
LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
ComputeValueResults computeValueResults,
bool lastCompute) {
func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
auto* context = &getContext();
auto loc = funcOp.getLoc();
@@ -1014,12 +335,11 @@ private:
auto channelVal = channelOp.getResult();
auto insertVal =
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(sendInsertPoint);
auto spatSend =
spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
return spatSend;
};
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(sendInsertPoint);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
return spatSend;
};
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
return ret;
};

View File

@@ -116,9 +116,10 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(),

View File

@@ -1,3 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
add_custom_target(pim-unittest)
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")

View File

@@ -457,10 +457,6 @@ int testDCPGraphDiamondDependencies() {
return 0;
}
// crossbarSize=4, crossbarCount=2 => capacity = 4*4*2 = 32.
// Each task with crossbarUsage=1 needs footprint = 4*4 = 16, so at most 1 task
// can fit per CPU (16+16 = 32 >= capacity). The scheduler must open a fresh CPU
// for each task; all three end up on separate CPUs with their base weight.
int testDCPGraphCrossbarExhaustion() {
std::cout << "testDCPGraphCrossbarExhaustion:" << std::endl;
configureDcpDotOutput();
@@ -478,35 +474,36 @@ int testDCPGraphCrossbarExhaustion() {
const std::vector<Weight> nodeWeights = {10, 10, 10};
const std::vector<CrossbarUsage> nodeCrossbarUsage = {1, 1, 1};
GraphDCP graph(nodeWeights, {}, nodeCrossbarUsage);
graph.setMaxCpuCount(3);
graph.setMaxCpuCount(1);
graph.runDcp();
if (graph.cpuCount() != 3) {
if (graph.cpuCount() != 1) {
restoreCrossbarOptions();
std::cerr << "Expected 3 CPUs (one per task due to crossbar limit), got " << graph.cpuCount() << "\n";
std::cerr << "Expected exactly 1 CPU with maxCpuCount=1, got " << graph.cpuCount() << "\n";
dumpDcpFailureArtifacts();
return 1;
}
int failures = 0;
for (CPU c = 0; c < 3; c++) {
auto scheduledTasks = graph.getScheduledTasks(c);
if (scheduledTasks.size() != 1) {
std::cerr << "Expected exactly 1 task on CPU " << c << ", got " << scheduledTasks.size() << "\n";
printCpuSchedule(graph, c);
failures++;
continue;
}
if (scheduledTasks[0].weight != 10) {
std::cerr << "Expected weight=10 on CPU " << c << ", got " << scheduledTasks[0].weight << "\n";
printCpuSchedule(graph, c);
failures++;
}
auto scheduledTasks = graph.getScheduledTasks(0);
if (scheduledTasks.size() != 3) {
restoreCrossbarOptions();
std::cerr << "Expected all three tasks to be scheduled on CPU 0\n";
printCpuSchedule(graph, 0);
dumpDcpFailureArtifacts();
return 1;
}
if (scheduledTasks[0].weight != 10 || scheduledTasks[1].weight != std::numeric_limits<Weight>::max()
|| scheduledTasks[2].weight != std::numeric_limits<Weight>::max()) {
restoreCrossbarOptions();
std::cerr << "Unexpected effective weights under crossbar exhaustion\n";
printCpuSchedule(graph, 0);
dumpDcpFailureArtifacts();
return 1;
}
restoreCrossbarOptions();
if (failures) dumpDcpFailureArtifacts();
return failures;
return 0;
}
} // namespace