fix mergeTriviallyConnectedComputes
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m48s
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m48s
This commit is contained in:
@@ -8,6 +8,7 @@
|
|||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
@@ -22,8 +23,8 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
|||||||
private:
|
private:
|
||||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||||
void mergeSingleChildCompute(func::FuncOp funcOp);
|
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -148,7 +149,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
mergeSingleChildCompute(*entryFunc);
|
mergeTriviallyConnectedComputes(*entryFunc);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "spatial0");
|
dumpModule(moduleOp, "spatial0");
|
||||||
@@ -230,66 +231,61 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::mergeSingleChildCompute(func::FuncOp funcOp) {
|
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||||
llvm::SmallVector<spatial::SpatWeightedCompute> computeSingleChild;
|
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
IRRewriter rewriter(&getContext());
|
IRRewriter rewriter(&getContext());
|
||||||
|
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
||||||
|
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||||
if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) {
|
if (compute->hasOneUse()) {
|
||||||
auto user = *compute->getUsers().begin();
|
auto user = *compute->getUsers().begin();
|
||||||
if (user->getNumOperands() == 1)
|
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
trivialComputes.push_back(compute);
|
||||||
computeSingleChild.push_back(compute);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
IRMapping mapper;
|
while (!trivialComputes.empty()) {
|
||||||
while (!computeSingleChild.empty()) {
|
auto compute = trivialComputes.front();
|
||||||
auto compute = computeSingleChild.front();
|
|
||||||
auto child = dyn_cast_if_present<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
if (compute.use_empty()) {
|
||||||
assert(child && "Child required!");
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
|
trivialComputes.pop_back();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{(int) compute.getWeights().size(), (int) compute.getInputs().size()});
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
llvm::dbgs() << "After Creation\n";
|
|
||||||
newCompute.dump();
|
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
llvm::dbgs() << "After Clone\n";
|
|
||||||
newCompute.dump();
|
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||||
newTerminator->erase();
|
newTerminator->erase();
|
||||||
llvm::dbgs() << "After terminator\n";
|
|
||||||
newCompute.dump();
|
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
|
|
||||||
for (auto& op : child.getBody().front())
|
for (auto& op : child.getBody().front())
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
child.replaceAllUsesWith(newCompute);
|
||||||
assert(child->getUses().empty() && "It's not obvius");
|
toErase.insert(child);
|
||||||
llvm::dbgs() << "Node\n";
|
|
||||||
newCompute.dump();
|
|
||||||
|
|
||||||
llvm::dbgs() << "Parent\n";
|
std::swap(trivialComputes.front(), trivialComputes.back());
|
||||||
compute.dump();
|
trivialComputes.pop_back();
|
||||||
|
toErase.insert(compute);
|
||||||
|
|
||||||
llvm::dbgs() << "Child\n";
|
if (newCompute->hasOneUse()) {
|
||||||
child.dump();
|
|
||||||
|
|
||||||
child.erase();
|
|
||||||
compute.erase();
|
|
||||||
|
|
||||||
if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) {
|
|
||||||
auto user = *newCompute->getUsers().begin();
|
auto user = *newCompute->getUsers().begin();
|
||||||
if (user->getNumOperands() == 1)
|
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user))
|
trivialComputes.push_back(newCompute);
|
||||||
computeSingleChild.push_back(newCompute);
|
|
||||||
}
|
}
|
||||||
std::swap(computeSingleChild.front(), computeSingleChild.back());
|
}
|
||||||
computeSingleChild.pop_back();
|
|
||||||
|
for (auto compute : toErase) {
|
||||||
|
compute.getResult(0).dropAllUses();
|
||||||
|
compute.erase();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user