MergeDCP pass all test
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
@@ -17,6 +18,19 @@ namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
|
||||
if (!op)
|
||||
return {};
|
||||
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
op = extract.getSource().getDefiningOp();
|
||||
if (!op)
|
||||
return {};
|
||||
}
|
||||
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
|
||||
return res;
|
||||
return {};
|
||||
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::runAnalysis() {
|
||||
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
|
||||
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
||||
@@ -27,8 +41,7 @@ DCPAnalysisResult DCPAnalysis::runAnalysis() {
|
||||
|
||||
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||
for (Value input : spatWeightedCompute.getInputs()) {
|
||||
if (auto spatWeightedComputeArgOp = llvm::dyn_cast_if_present<SpatWeightedCompute>(input.getDefiningOp());
|
||||
spatWeightedComputeArgOp) {
|
||||
if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
|
||||
auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
|
||||
assert(elemIter != spatWeightedComputes.end());
|
||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);
|
||||
|
||||
Reference in New Issue
Block a user