@@ -44,17 +44,17 @@ static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, R
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
|
||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
|
||||
Value physicalValue) {
|
||||
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||
if (!logicalType)
|
||||
return reconciliator.emitOpError("requires ranked logical output type"), failure();
|
||||
return blueprint.emitOpError("requires ranked logical output type"), failure();
|
||||
RowStripPhysicalValue value;
|
||||
value.physicalValue = physicalValue;
|
||||
value.logicalType = logicalType;
|
||||
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
|
||||
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
|
||||
value.indexMap = reconciliator.getIndexMap().str();
|
||||
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
|
||||
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
|
||||
value.indexMap = blueprint.getIndexMap().str();
|
||||
return value;
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||
return true;
|
||||
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
|
||||
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
|
||||
signalPassFailure();
|
||||
return false;
|
||||
};
|
||||
@@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
|
||||
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||
planOp,
|
||||
@@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
|
||||
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
|
||||
if (failed(rowStripValue)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *rowStripValue;
|
||||
rowStripValues[blueprint.getResult()] = *rowStripValue;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
eraseAfterLowering.insert(blueprint);
|
||||
continue;
|
||||
}
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
@@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
|
||||
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (outputReconciliator == planOp.getResult().getUsers().end()) {
|
||||
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
|
||||
if (outputBlueprint == planOp.getResult().getUsers().end()) {
|
||||
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
|
||||
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
|
||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
|
||||
if (failed(output)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *output;
|
||||
rowStripValues[blueprint.getResult()] = *output;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
eraseAfterLowering.insert(blueprint);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
}
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||
if (failed(rowStripValue)) {
|
||||
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
|
||||
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
rewriter.replaceOp(materializeOp, *dense);
|
||||
continue;
|
||||
}
|
||||
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
|
||||
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
|
||||
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
|
||||
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
|
||||
continue;
|
||||
}
|
||||
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
|
||||
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
|
||||
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
|
||||
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (!eraseAfterLowering.contains(reconciliatorOp)) {
|
||||
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
|
||||
if (!eraseAfterLowering.contains(blueprintOp)) {
|
||||
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
return;
|
||||
if (isa<spatial::SpatConv2DPlanOp,
|
||||
spatial::SpatReluPlanOp,
|
||||
spatial::SpatReconciliatorOp,
|
||||
spatial::SpatBlueprintOp,
|
||||
spatial::SpatMaterializeLayoutOp>(op)
|
||||
|| op->getDialect()->getNamespace() == "onnx") {
|
||||
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||
|
||||
Reference in New Issue
Block a user