Add register reuse + peft scheduler cost model + Useless merger

This commit is contained in:
ilgeco
2026-06-18 10:56:57 +02:00
parent 852bef7605
commit e083c27d80
13 changed files with 350 additions and 20 deletions
@@ -19,9 +19,11 @@ using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
if (!rankedType)
return false;
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
if (!sourceType)
return false;
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
return false;
return true;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
return false;
}
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
continue;
}
if (isWeightLikeComputeOperand(operand)) {
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
@@ -26,6 +26,82 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
});
}
static bool isMaterializableExternalTensorOp(Operation* op) {
return isa<spatial::SpatChannelReceiveOp,
spatial::SpatExtractRowsOp,
tensor::ExtractSliceOp,
tensor::ExpandShapeOp,
tensor::CollapseShapeOp>(op);
}
//TODO REMOVE THIS UGLY FIX
//TODO: Remove this helper once compute_batch external tensor captures are
// fixed at the producer side.
//
// This function is a temporary SpatialToPim repair path. It clones selected
// external tensor producers, such as channel_receive and tensor view/slice ops,
// into the new pim.core_batch body when the old spat.compute_batch body refers
// to tensor values defined outside the batch.
//
// The real invariant should be stronger:
//
// A spat.compute_batch body must not capture external tensor values.
// Every tensor used inside the body must be either:
// - a compute_batch block argument,
// - defined inside the compute_batch body,
// - or a legal constant-like value.
//
// If this invariant is violated, the responsible producer, most likely merge
// schedule materialization, should emit verifier-clean Spatial IR instead of
// relying on SpatialToPim to clone external producer chains later.
//
// After that producer-side fix:
// 1. remove isMaterializableExternalTensorOp,
// 2. remove materializeExternalTensorValue,
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
// tensor operand,
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
// before SpatialToPim.
//
// Be careful not to replace every external tensor capture with a normal
// compute_batch input blindly: host-backed tensors and explicit inter-core
// communication have different semantics. In particular, channel_receive-like
// values should be materialized through the communication model, not silently
// treated as host inputs.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
Location loc,
Block& oldBlock,
Value value,
IRMapping& mapper) {
if (mapper.contains(value))
return mapper.lookup(value);
if (!isa<TensorType>(value.getType()))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
return failure();
if (definingOp->getBlock() == &oldBlock)
return failure();
if (!isMaterializableExternalTensorOp(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
if (succeeded(materializedOperand))
mapper.map(operand, *materializedOperand);
}
Operation* cloned = rewriter.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
@@ -264,9 +340,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
continue;
InFlightDiagnostic diagnostic =
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
if (definingOp)
diagnostic << " from external producer '" << definingOp->getName() << "'";
return diagnostic;
}
Operation* cloned = rewriter.clone(op, mapper);