better spat computes merging
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m14s

This commit is contained in:
NiccoloN
2026-04-25 19:24:09 +02:00
parent 951baca106
commit 15e8edb9c4
11 changed files with 1477 additions and 369 deletions

View File

@@ -8,7 +8,6 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
@@ -24,8 +23,6 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -52,7 +49,6 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
@@ -155,8 +151,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
mergeTriviallyConnectedComputes(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial0");
}
@@ -187,8 +181,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
if (llvm::any_of(sources,
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
@@ -294,100 +288,6 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
}
}
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) {
auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute);
}
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto& use = *newCompute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))

View File

@@ -96,8 +96,8 @@ bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value
createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
@@ -127,6 +127,16 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
for (Operation* user : value.getUsers()) {
if (user->getBlock() != operation->getBlock())
return true;
if (operation->isBeforeInBlock(user))
return true;
}
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0);
@@ -134,8 +144,9 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands =
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
});
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())