Add register reuse + peft scheduler cost model + Useless merger

This commit is contained in:
ilgeco
2026-06-18 10:56:57 +02:00
parent 852bef7605
commit e083c27d80
13 changed files with 350 additions and 20 deletions
@@ -19,9 +19,11 @@ using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
if (!rankedType)
return false;
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
if (!sourceType)
return false;
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
return false;
return true;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
return false;
}
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
continue;
}
if (isWeightLikeComputeOperand(operand)) {
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();