Refactor ONNXToSpatial Common and diagnostics
This commit is contained in:
@@ -12,14 +12,13 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
@@ -32,8 +31,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool haveSameStaticShape(Value lhs, Value rhs);
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
@@ -50,7 +47,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
@@ -186,7 +183,10 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
@@ -287,64 +287,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!tensorType || !tensorType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(tensorType.getRank());
|
||||
for (int64_t dim : tensorType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
|
||||
auto referencedValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||
mapper.map(value, referencedValue.getResult());
|
||||
return referencedValue.getResult();
|
||||
}
|
||||
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||
return failure();
|
||||
|
||||
IRMapping localMapper;
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||
localMapper.map(operand, cast<Value>(mapped));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
localMapper.map(operand, *clonedOperand);
|
||||
continue;
|
||||
}
|
||||
|
||||
localMapper.map(operand, operand);
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
auto mapped = mapper.lookupOrNull(value);
|
||||
if (!mapped)
|
||||
return failure();
|
||||
return cast<Value>(mapped);
|
||||
}
|
||||
|
||||
bool sourceOpernadHasWeightAlways(Operation* op) {
|
||||
static FailureOr<bool> sourceOperandHasWeightAlways(Operation* op) {
|
||||
if (op == nullptr)
|
||||
return false;
|
||||
|
||||
@@ -416,30 +359,32 @@ bool sourceOpernadHasWeightAlways(Operation* op) {
|
||||
return res;
|
||||
}
|
||||
else {
|
||||
op->dump();
|
||||
llvm_unreachable("Global instruction not handle in func");
|
||||
op->emitOpError("unsupported global instruction while promoting weight-backed operands into Spatial computes");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
while (source == nullptr);
|
||||
|
||||
if (hasWeightAlways(source))
|
||||
return true;
|
||||
return false;
|
||||
return hasWeightAlways(source);
|
||||
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
bool keep = true;
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||
instruction)
|
||||
|| isa<func::ReturnOp>(instruction)
|
||||
|| sourceOpernadHasWeightAlways(&instruction))
|
||||
|| isa<func::ReturnOp>(instruction))
|
||||
continue;
|
||||
|
||||
auto weightBacked = sourceOperandHasWeightAlways(&instruction);
|
||||
if (failed(weightBacked))
|
||||
return failure();
|
||||
if (*weightBacked)
|
||||
continue;
|
||||
|
||||
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||
@@ -456,6 +401,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
|
||||
Reference in New Issue
Block a user