merge remote changes
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
@@ -11,6 +12,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
@@ -183,6 +185,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
|
||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||
@@ -199,19 +202,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
||||
auto source = toRemoveOp.getSource();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -245,6 +265,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
sourceTypes.push_back(source.getType());
|
||||
sourceLoc.push_back(loc);
|
||||
}
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -306,6 +344,89 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
|
||||
return cast<Value>(mapped);
|
||||
}
|
||||
|
||||
bool sourceOpernadHasWeightAlways(Operation* op) {
|
||||
if (op == nullptr)
|
||||
return false;
|
||||
|
||||
Operation* source = nullptr;
|
||||
do {
|
||||
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch>(*op)) {
|
||||
return false;
|
||||
}
|
||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
|
||||
auto tmpSource = extractSliceOp.getSource();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(*op)) {
|
||||
auto tmpSource = extractRowsOp.getInput();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
|
||||
auto tmpSource = expandShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
|
||||
auto tmpSource = transposeOp.getData();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
|
||||
auto tmpSource = collapseShapeOp.getSrc();
|
||||
auto definingOp = tmpSource.getDefiningOp();
|
||||
if (definingOp)
|
||||
op = definingOp;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
|
||||
source = constantOp;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(*op)) {
|
||||
bool res = false;
|
||||
for (auto operand : concatOp.getOperands()) {
|
||||
res |= hasWeightAlways(operand.getDefiningOp());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
else {
|
||||
op->dump();
|
||||
llvm_unreachable("Global instruction not handle in func");
|
||||
}
|
||||
}
|
||||
while (source == nullptr);
|
||||
|
||||
if (hasWeightAlways(source))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
@@ -314,8 +435,14 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
||||
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||
instruction)
|
||||
|| isa<func::ReturnOp>(instruction)
|
||||
|| sourceOpernadHasWeightAlways(&instruction))
|
||||
continue;
|
||||
|
||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||
|
||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||
|
||||
Reference in New Issue
Block a user