Add register reuse + peft scheduler cost model + Useless merger
This commit is contained in:
@@ -19,9 +19,11 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) {
|
||||
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
|
||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||
if (!rankedType)
|
||||
return false;
|
||||
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!sourceType)
|
||||
return false;
|
||||
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
|
||||
|
||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
|
||||
Reference in New Issue
Block a user