This commit is contained in:
@@ -163,6 +163,38 @@ Value extractAxisSlice(
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||
|
||||
bool isIdentitySlice =
|
||||
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||
&& strides.size() == rank;
|
||||
if (isIdentitySlice) {
|
||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
||||
isIdentitySlice = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isIdentitySlice)
|
||||
return source;
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
|
||||
@@ -105,6 +105,14 @@ llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPer
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
|
||||
@@ -44,17 +44,17 @@ static FailureOr<RowStripPhysicalValue> getRowStripValue(llvm::DenseMap<Value, R
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatReconciliatorOp reconciliator,
|
||||
static FailureOr<RowStripPhysicalValue> buildRowStripValue(spatial::SpatBlueprintOp blueprint,
|
||||
Value physicalValue) {
|
||||
auto logicalType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
|
||||
auto logicalType = dyn_cast<RankedTensorType>(blueprint.getOutput().getType());
|
||||
if (!logicalType)
|
||||
return reconciliator.emitOpError("requires ranked logical output type"), failure();
|
||||
return blueprint.emitOpError("requires ranked logical output type"), failure();
|
||||
RowStripPhysicalValue value;
|
||||
value.physicalValue = physicalValue;
|
||||
value.logicalType = logicalType;
|
||||
value.fragmentOffsets.append(reconciliator.getFragmentOffsets().begin(), reconciliator.getFragmentOffsets().end());
|
||||
value.fragmentSizes.append(reconciliator.getFragmentSizes().begin(), reconciliator.getFragmentSizes().end());
|
||||
value.indexMap = reconciliator.getIndexMap().str();
|
||||
value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
|
||||
value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
|
||||
value.indexMap = blueprint.getIndexMap().str();
|
||||
return value;
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
auto verifyLogicalPhase = [&](StringRef stage) -> bool {
|
||||
if (succeeded(verifyLogicalSpatialGraphInvariants(*entryFunc)))
|
||||
return true;
|
||||
moduleOp.emitError() << "RAPTOR_PHASE_CHECK logical Spatial graph verification failed " << stage;
|
||||
moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
|
||||
signalPassFailure();
|
||||
return false;
|
||||
};
|
||||
@@ -185,11 +185,11 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
|
||||
if (auto planOp = dyn_cast<spatial::SpatConv2DPlanOp>(&op)) {
|
||||
FailureOr<RowStripPhysicalValue> rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
|
||||
auto rowStripReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (rowStripReconciliator != planOp.getResult().getUsers().end()) {
|
||||
if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
FailureOr<Value> lowered = lowerSelectedConv2DPlan(
|
||||
planOp,
|
||||
@@ -201,15 +201,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*rowStripReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(reconciliator, *lowered);
|
||||
auto blueprint = cast<spatial::SpatBlueprintOp>(*rowStripBlueprint);
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = buildRowStripValue(blueprint, *lowered);
|
||||
if (failed(rowStripValue)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *rowStripValue;
|
||||
rowStripValues[blueprint.getResult()] = *rowStripValue;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
eraseAfterLowering.insert(blueprint);
|
||||
continue;
|
||||
}
|
||||
rewriter.setInsertionPoint(planOp);
|
||||
@@ -226,12 +226,12 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
|
||||
if (auto planOp = dyn_cast<spatial::SpatReluPlanOp>(&op)) {
|
||||
if (succeeded(getRowStripValue(rowStripValues, planOp.getInput()))) {
|
||||
auto outputReconciliator = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(user);
|
||||
return reconciliator && reconciliator.getPhysicalLayout() == kRowStripLayout;
|
||||
auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
|
||||
auto blueprint = dyn_cast<spatial::SpatBlueprintOp>(user);
|
||||
return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
|
||||
});
|
||||
if (outputReconciliator == planOp.getResult().getUsers().end()) {
|
||||
planOp.emitOpError("row-strip Relu plan requires a row-strip reconciliator result");
|
||||
if (outputBlueprint == planOp.getResult().getUsers().end()) {
|
||||
planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -244,15 +244,15 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
auto reconciliator = cast<spatial::SpatReconciliatorOp>(*outputReconciliator);
|
||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(reconciliator, *lowered);
|
||||
auto blueprint = cast<spatial::SpatBlueprintOp>(*outputBlueprint);
|
||||
FailureOr<RowStripPhysicalValue> output = buildRowStripValue(blueprint, *lowered);
|
||||
if (failed(output)) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
rowStripValues[reconciliator.getResult()] = *output;
|
||||
rowStripValues[blueprint.getResult()] = *output;
|
||||
eraseAfterLowering.insert(planOp);
|
||||
eraseAfterLowering.insert(reconciliator);
|
||||
eraseAfterLowering.insert(blueprint);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
}
|
||||
FailureOr<RowStripPhysicalValue> rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
|
||||
if (failed(rowStripValue)) {
|
||||
materializeOp.emitOpError("expected a row-strip reconciliator input during row-strip materialization");
|
||||
materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -293,18 +293,18 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
rewriter.replaceOp(materializeOp, *dense);
|
||||
continue;
|
||||
}
|
||||
if (auto reconciliatorOp = dyn_cast<spatial::SpatReconciliatorOp>(&op)) {
|
||||
if (reconciliatorOp.getPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(reconciliatorOp, reconciliatorOp.getInput());
|
||||
if (auto blueprintOp = dyn_cast<spatial::SpatBlueprintOp>(&op)) {
|
||||
if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
|
||||
rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
|
||||
continue;
|
||||
}
|
||||
if (reconciliatorOp.getPhysicalLayout() != kRowStripLayout) {
|
||||
reconciliatorOp.emitOpError("non-dense reconciliator lowering is not supported yet");
|
||||
if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
|
||||
blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
if (!eraseAfterLowering.contains(reconciliatorOp)) {
|
||||
reconciliatorOp.emitOpError("unhandled row-strip reconciliator remained during LowerSpatialPlans");
|
||||
if (!eraseAfterLowering.contains(blueprintOp)) {
|
||||
blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -385,7 +385,7 @@ struct LowerSpatialPlansPass final : PassWrapper<LowerSpatialPlansPass, Operatio
|
||||
return;
|
||||
if (isa<spatial::SpatConv2DPlanOp,
|
||||
spatial::SpatReluPlanOp,
|
||||
spatial::SpatReconciliatorOp,
|
||||
spatial::SpatBlueprintOp,
|
||||
spatial::SpatMaterializeLayoutOp>(op)
|
||||
|| op->getDialect()->getNamespace() == "onnx") {
|
||||
op->emitOpError("operation must not remain after LowerSpatialPlans");
|
||||
|
||||
@@ -46,9 +46,9 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
|
||||
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
|
||||
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
|
||||
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
|
||||
SmallVector<spatial::SpatBlueprintOp> blueprints(funcOp.getOps<spatial::SpatBlueprintOp>());
|
||||
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
|
||||
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|
||||
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty()
|
||||
|| !materializers.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -160,7 +160,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
}
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
|
||||
moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -181,7 +181,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
|
||||
moduleOp.emitError("logical Spatial graph verification failed after weight annotation");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -199,7 +199,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
|
||||
moduleOp.emitError("logical Spatial graph verification failed before post rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -214,7 +214,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||
moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr StringLiteral kPhaseMarker = "RAPTOR_PHASE_CHECK";
|
||||
constexpr StringLiteral kPhaseMarker = "phase-check";
|
||||
|
||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
func.walk([&](Operation* op) {
|
||||
@@ -114,14 +114,14 @@ void verifyScheduledInputs(ComputeOpTy compute,
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
void verifyNoNestedFragmentAssemblyReconciliators(ComputeOpTy compute,
|
||||
void verifyNoNestedFragmentAssemblyBlueprints(ComputeOpTy compute,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
compute.getBody().walk([&](spatial::SpatReconciliatorOp reconciliator) {
|
||||
std::optional<StringRef> mode = reconciliator.getMode();
|
||||
compute.getBody().walk([&](spatial::SpatBlueprintOp blueprint) {
|
||||
std::optional<StringRef> mode = blueprint.getMode();
|
||||
if (!mode || *mode != "fragment_assembly")
|
||||
return;
|
||||
diagnostics.report(reconciliator.getOperation(), [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("fragment assembly reconciliator must be host-level after merge materialization");
|
||||
diagnostics.report(blueprint.getOperation(), [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("fragment assembly blueprint must be host-level after merge materialization");
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -133,7 +133,7 @@ void verifyLogicalTopLevelOps(func::FuncOp funcOp, pim::CappedDiagnosticReporter
|
||||
spatial::SpatGraphComputeBatch,
|
||||
spatial::SpatConv2DPlanOp,
|
||||
spatial::SpatReluPlanOp,
|
||||
spatial::SpatReconciliatorOp,
|
||||
spatial::SpatBlueprintOp,
|
||||
spatial::SpatMaterializeLayoutOp>(&op)) {
|
||||
continue;
|
||||
}
|
||||
@@ -203,11 +203,11 @@ LogicalResult verifyScheduledSpatialInvariants(func::FuncOp funcOp) {
|
||||
verifyScheduledTopLevelOps(funcOp, diagnostics);
|
||||
for (auto compute : funcOp.getOps<spatial::SpatScheduledCompute>()) {
|
||||
verifyScheduledInputs(compute, /*allowChannelReceiveInputs=*/true, "spat.scheduled_compute", diagnostics);
|
||||
verifyNoNestedFragmentAssemblyReconciliators(compute, diagnostics);
|
||||
verifyNoNestedFragmentAssemblyBlueprints(compute, diagnostics);
|
||||
}
|
||||
for (auto batch : funcOp.getOps<spatial::SpatScheduledComputeBatch>()) {
|
||||
verifyScheduledInputs(batch, /*allowChannelReceiveInputs=*/false, "spat.scheduled_compute_batch", diagnostics);
|
||||
verifyNoNestedFragmentAssemblyReconciliators(batch, diagnostics);
|
||||
verifyNoNestedFragmentAssemblyBlueprints(batch, diagnostics);
|
||||
}
|
||||
if (failed(verifyNoComputeBodyCaptures(funcOp)))
|
||||
return failure();
|
||||
|
||||
@@ -2242,8 +2242,8 @@ static FailureOr<Value> rewriteInputKTiledConv(const ConvLoweringState& state,
|
||||
rewriter, reduceLoc, paddedRowType, paddedPatchRow, aOffsets, aSizes, unitStrides);
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||
Value bTile = tensor::ExtractSliceOp::create(
|
||||
rewriter, reduceLoc, weightTileType, weightArg, bOffsets, bSizes, unitStrides);
|
||||
Value bTile = extractStaticSliceOrIdentity(
|
||||
rewriter, reduceLoc, weightArg, weightTileType, bOffsets, bSizes, unitStrides);
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||
reduceYielded.push_back(
|
||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, acc, piece).getResult());
|
||||
@@ -2912,8 +2912,13 @@ static FailureOr<Value> createConvOutputFromRowStripHwc(Value inputHwc,
|
||||
rewriter, reduceLoc, paddedRowType, paddedRow, aOffsets, aSizes, getUnitStrides(rewriter, 2));
|
||||
SmallVector<OpFoldResult> bOffsets {kOffset, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarDim), rewriter.getIndexAttr(xbarDim)};
|
||||
Value bTile = tensor::ExtractSliceOp::create(
|
||||
rewriter, reduceLoc, paddedWeightTileType, args.weights.front(), bOffsets, bSizes, getUnitStrides(rewriter, 2));
|
||||
Value bTile = extractStaticSliceOrIdentity(rewriter,
|
||||
reduceLoc,
|
||||
args.weights.front(),
|
||||
paddedWeightTileType,
|
||||
bOffsets,
|
||||
bSizes,
|
||||
getUnitStrides(rewriter, 2));
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, reduceLoc, paddedRowType, bTile, aTile).getResult();
|
||||
reduceYielded.push_back(
|
||||
spatial::SpatVAddOp::create(rewriter, reduceLoc, paddedRowType, reduceIterArgs.front(), piece).getResult());
|
||||
|
||||
@@ -285,9 +285,8 @@ static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
||||
SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(crossbarSize.getValue()),
|
||||
rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, 2);
|
||||
Value bTile =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
|
||||
.getResult();
|
||||
Value bTile = extractStaticSliceOrIdentity(
|
||||
rewriter, loc, args.weights.front(), bTileType, bOffsets, bSizes, unitStrides);
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
|
||||
@@ -90,10 +90,10 @@ static SelectedLayout chooseReluLayout(spatial::SpatReluPlanOp reluPlan,
|
||||
return SelectedLayout::NchwRowStrip;
|
||||
}
|
||||
|
||||
static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewriter, Value value) {
|
||||
static spatial::SpatBlueprintOp insertRowStripBlueprint(IRRewriter& rewriter, Value value) {
|
||||
auto outputType = cast<RankedTensorType>(value.getType());
|
||||
auto [offsets, sizes] = buildRowStripMetadata(outputType);
|
||||
return spatial::SpatReconciliatorOp::create(rewriter,
|
||||
return spatial::SpatBlueprintOp::create(rewriter,
|
||||
value.getLoc(),
|
||||
outputType,
|
||||
value,
|
||||
@@ -189,12 +189,12 @@ struct SpatialLayoutPlanningPass final : PassWrapper<SpatialLayoutPlanningPass,
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPointAfter(&op);
|
||||
auto reconciliator = insertRowStripReconciliator(rewriter, producedValue);
|
||||
rewriter.replaceAllUsesExcept(producedValue, reconciliator.getResult(), reconciliator);
|
||||
materializeDenseUses(rewriter, reconciliator.getResult(), layouts);
|
||||
auto blueprint = insertRowStripBlueprint(rewriter, producedValue);
|
||||
rewriter.replaceAllUsesExcept(producedValue, blueprint.getResult(), blueprint);
|
||||
materializeDenseUses(rewriter, blueprint.getResult(), layouts);
|
||||
}
|
||||
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
|
||||
getOperation().emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||
getOperation().emitError("logical Spatial graph verification failed after SpatialLayoutPlanning");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user