Bose
Validate Operations / validate-operations (push) Waiting to run

This commit is contained in:
ilgeco
2026-06-26 17:45:27 +02:00
parent 984f362623
commit 78e97f9fd8
23 changed files with 513 additions and 17489 deletions
@@ -44,17 +44,17 @@ static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, R
return it->second;
}
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
Value physicalValue) {
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
if (!logicalType)
return reconciliator.emitOpError("requires ranked logical output type"), failure();
return blueprint.emitOpError("requires ranked logical output type"), failure();
RowStripPhysicalValue value;
value.physicalValue = physicalValue;
value.logicalType = logicalType;
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
value.indexMap = reconciliator.getIndexMap().str();
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
value.indexMap = blueprint.getIndexMap().str();
return value;
}
@@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
return true;
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
signalPassFailure();
return false;
};
@@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
});
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
rewriter.setInsertionPoint(planOp);
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
planOp,
@@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
if (failed(rowStripValue)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *rowStripValue;
rowStripValues[blueprint.getResult()] = *rowStripValue;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
eraseAfterLowering.insert(blueprint);
continue;
}
rewriter.setInsertionPoint(planOp);
@@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
});
if (outputReconciliator == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
if (outputBlueprint == planOp.getResult().getUsers().end()) {
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
signalPassFailure();
return;
}
@@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
signalPassFailure();
return;
}
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
if (failed(output)) {
signalPassFailure();
return;
}
rowStripValues[reconciliator.getResult()] = *output;
rowStripValues[blueprint.getResult()] = *output;
eraseAfterLowering.insert(planOp);
eraseAfterLowering.insert(reconciliator);
eraseAfterLowering.insert(blueprint);
continue;
}
@@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
}
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
if (failed(rowStripValue)) {
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
signalPassFailure();
return;
}
@@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
rewriter.replaceOp(materializeOp, *dense);
continue;
}
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
continue;
}
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
signalPassFailure();
return;
}
if (!eraseAfterLowering.contains(reconciliatorOp)) {
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
if (!eraseAfterLowering.contains(blueprintOp)) {
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
signalPassFailure();
return;
}
@@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
return;
if (isa<spatial::SpatConv2DPlanOp,
spatial::SpatReluPlanOp,
spatial::SpatReconciliatorOp,
spatial::SpatBlueprintOp,
spatial::SpatMaterializeLayoutOp>(op)
|| op->getDialect()->getNamespace() == "onnx") {
op->emitOpError("operation must not remain after LowerSpatialPlans");