better spat computes merging
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m14s
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m14s
This commit is contained in:
@@ -8,7 +8,6 @@
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
@@ -24,8 +23,6 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -52,7 +49,6 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
@@ -155,8 +151,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
mergeTriviallyConnectedComputes(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial0");
|
||||
}
|
||||
@@ -187,8 +181,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||
auto sources = toRemoveOp.getInputs();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (llvm::any_of(
|
||||
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||
if (llvm::any_of(sources,
|
||||
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
@@ -294,100 +288,6 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatCompute> trivialComputes;
|
||||
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||
if (compute->hasOneUse()) {
|
||||
auto& use = *compute->getUses().begin();
|
||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||
|
||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||
trivialComputes.push_back(compute);
|
||||
}
|
||||
|
||||
while (!trivialComputes.empty()) {
|
||||
auto compute = trivialComputes.front();
|
||||
|
||||
if (compute.use_empty()) {
|
||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||
trivialComputes.pop_back();
|
||||
continue;
|
||||
}
|
||||
auto& computeUse = *compute->getUses().begin();
|
||||
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
|
||||
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||
|
||||
IRMapping mapper;
|
||||
auto weightMutableIter = newCompute.getWeightsMutable();
|
||||
for (auto weight : child.getWeights()) {
|
||||
auto founded = llvm::find(newCompute.getWeights(), weight);
|
||||
if (founded == newCompute.getWeights().end()) {
|
||||
weightMutableIter.append(weight);
|
||||
auto last = weightMutableIter.end();
|
||||
last = std::prev(last, 1);
|
||||
mapper.map(weight, last->get());
|
||||
}
|
||||
else {
|
||||
mapper.map(weight, *founded);
|
||||
}
|
||||
}
|
||||
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||
newTerminator->erase();
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
for (auto& op : child.getBody().front()) {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
}
|
||||
|
||||
child.replaceAllUsesWith(newCompute);
|
||||
toErase.insert(child);
|
||||
|
||||
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||
trivialComputes.pop_back();
|
||||
toErase.insert(compute);
|
||||
|
||||
if (newCompute->hasOneUse()) {
|
||||
auto& use = *newCompute->getUses().begin();
|
||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||
trivialComputes.push_back(newCompute);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto compute : toErase) {
|
||||
for (Value result : compute->getResults())
|
||||
result.dropAllUses();
|
||||
compute.erase();
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
|
||||
@@ -96,8 +96,8 @@ bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
|
||||
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
|
||||
}
|
||||
|
||||
mlir::Value createPimReceiveFromSpatialChannel(
|
||||
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||
mlir::Value
|
||||
createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
|
||||
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
|
||||
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
|
||||
@@ -127,6 +127,16 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
|
||||
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
|
||||
}
|
||||
|
||||
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (user->getBlock() != operation->getBlock())
|
||||
return true;
|
||||
if (operation->isBeforeInBlock(user))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
|
||||
assert("Only support operations with a single result" && operation->getNumResults() == 1);
|
||||
mlir::Value result = operation->getResult(0);
|
||||
@@ -134,8 +144,9 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter,
|
||||
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
|
||||
|
||||
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
|
||||
auto validOperands =
|
||||
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
|
||||
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
|
||||
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
|
||||
});
|
||||
auto bestOperand = validOperands.begin();
|
||||
|
||||
if (bestOperand != validOperands.end())
|
||||
|
||||
Reference in New Issue
Block a user