#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"

#include "Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Physical-layout names compared against blueprint / materialize_layout attributes.
static constexpr StringLiteral kDenseLayout = "dense_nchw";
static constexpr StringLiteral kRowStripLayout = "nchw_row_strip";

/// Bookkeeping for a logical value whose lowered (physical) form uses the
/// row-strip layout. Built from a row-strip blueprint op by buildRowStripValue.
struct RowStripPhysicalValue {
  /// The lowered value actually carrying the data.
  Value physicalValue;
  /// Ranked type of the blueprint's logical output.
  RankedTensorType logicalType;
  /// Per-fragment offsets / sizes into the logical tensor, copied from the
  /// blueprint and flattened: entry `fragmentIndex * rank + dim`.
  SmallVector fragmentOffsets;
  SmallVector fragmentSizes;
  /// Index-map identifier copied from the blueprint (e.g. "packed_hwc_rows_to_nchw").
  std::string indexMap;
};

/// Looks up the row-strip bookkeeping recorded for `value`; fails if none exists.
static FailureOr getRowStripValue(llvm::DenseMap& rowStripValues, Value value) {
  auto it = rowStripValues.find(value);
  if (it == rowStripValues.end())
    return failure();
  return it->second;
}

/// Captures the layout metadata of a row-strip `blueprint` together with the
/// lowered `physicalValue` that now backs it. Emits an error and fails if the
/// blueprint's logical output type is not ranked.
static FailureOr buildRowStripValue(spatial::SpatBlueprintOp blueprint, Value physicalValue) {
  auto logicalType = dyn_cast(blueprint.getOutput().getType());
  if (!logicalType) {
    blueprint.emitOpError("requires ranked logical output type");
    return failure();
  }

  RowStripPhysicalValue value;
  value.physicalValue = physicalValue;
  value.logicalType = logicalType;
  value.fragmentOffsets.append(blueprint.getFragmentOffsets().begin(), blueprint.getFragmentOffsets().end());
  value.fragmentSizes.append(blueprint.getFragmentSizes().begin(), blueprint.getFragmentSizes().end());
  value.indexMap = blueprint.getIndexMap().str();
  return value;
}

/// Lowers a Relu plan whose input is already in row-strip form: the ReLU is
/// applied elementwise to the packed physical value inside a single Spatial
/// compute op, so the result keeps the input's packed type.
static FailureOr lowerRowStripRelu(const RowStripPhysicalValue& input,
                                   spatial::SpatReluPlanOp planOp,
                                   PatternRewriter& rewriter) {
  auto packedType = cast(input.physicalValue.getType());
  auto computeOp = createSpatCompute<1>(
      rewriter, planOp.getLoc(), TypeRange {packedType}, {}, input.physicalValue, [&](Value x) {
        auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), packedType, x);
        spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
      });
  return computeOp.getResult(0);
}

/// Converts a row-strip value back to its dense logical NCHW tensor.
///
/// Supported only for the "packed_hwc_rows_to_nchw" index map, where the packed
/// value is a static rank-3 tensor [fragments, W, C] and fragment `i` maps to the
/// logical slice offsets [0, 0, i, 0] / sizes [1, C, 1, W]. One batch lane per
/// fragment extracts its packed row, reshapes [1, W, C] -> [1, 1, W, C],
/// transposes to [1, C, 1, W], and inserts it into the batch output.
///
/// Returns failure (without emitting a diagnostic) for any unsupported or
/// inconsistent layout; the caller reports the error.
static FailureOr materializeRowStripToDense(const RowStripPhysicalValue& rowStripValue,
                                            Location loc,
                                            PatternRewriter& rewriter) {
  auto packedType = dyn_cast(rowStripValue.physicalValue.getType());
  if (!packedType || packedType.getRank() != 3 || !packedType.hasStaticShape())
    return failure();
  if (rowStripValue.logicalType.getRank() != 4 || !rowStripValue.logicalType.hasStaticShape())
    return failure();
  if (rowStripValue.indexMap != "packed_hwc_rows_to_nchw")
    return failure();

  const int64_t rank = rowStripValue.logicalType.getRank();
  const auto& offsets = rowStripValue.fragmentOffsets;
  const auto& sizes = rowStripValue.fragmentSizes;

  // Both arrays are read as [fragmentIndex * rank + dim] below. Reject truncated or
  // mismatched metadata instead of indexing past the end of either array.
  const int64_t flatCount = static_cast<int64_t>(offsets.size());
  if (static_cast<int64_t>(sizes.size()) != flatCount || flatCount % rank != 0)
    return failure();

  const int64_t fragmentCount = flatCount / rank;
  const int64_t packedWidth = packedType.getDimSize(1);
  const int64_t packedChannels = packedType.getDimSize(2);
  if (fragmentCount != packedType.getDimSize(0))
    return failure();

  // The fragments (one per row, full channels, full width) must tile the entire
  // logical tensor; otherwise the inserted slices would not fit the output or
  // parts of the batch output would presumably be left unwritten.
  auto logicalShape = rowStripValue.logicalType.getShape();
  if (logicalShape[0] != 1 || logicalShape[1] != packedChannels || logicalShape[2] != fragmentCount
      || logicalShape[3] != packedWidth)
    return failure();

  for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) {
    const int64_t base = fragmentIndex * rank;
    if (offsets[base + 0] != 0 || offsets[base + 1] != 0 || offsets[base + 2] != fragmentIndex
        || offsets[base + 3] != 0)
      return failure();
    if (sizes[base + 0] != 1 || sizes[base + 1] != packedChannels || sizes[base + 2] != 1
        || sizes[base + 3] != packedWidth)
      return failure();
  }

  auto packedSliceType =
      RankedTensorType::get({1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
  auto expandedType = RankedTensorType::get(
      {1, 1, packedWidth, packedChannels}, packedType.getElementType(), packedType.getEncoding());
  auto logicalFragmentType = RankedTensorType::get(
      {1, packedChannels, 1, packedWidth}, packedType.getElementType(), packedType.getEncoding());

  auto batchOp = createSpatComputeBatch(
      rewriter,
      loc,
      TypeRange {rowStripValue.logicalType},
      fragmentCount,
      {},
      ValueRange {rowStripValue.physicalValue},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // Packed row `lane`: [lane, 0, 0] / [1, W, C].
        SmallVector packedOffsets {args.lane, rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector packedSizes {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(packedWidth), rewriter.getIndexAttr(packedChannels)};
        Value packedSlice = tensor::ExtractSliceOp::create(rewriter,
                                                           loc,
                                                           packedSliceType,
                                                           args.inputs.front(),
                                                           packedOffsets,
                                                           packedSizes,
                                                           getUnitStrides(rewriter, 3));
        // [1, W, C] -> [1, 1, W, C]: the leading unit dim is split in two.
        Value expanded = tensor::ExpandShapeOp::create(rewriter,
                                                       loc,
                                                       expandedType,
                                                       packedSlice,
                                                       SmallVector {
                                                           {0, 1},
                                                           {2},
                                                           {3}
        });
        // [1, 1, W, C] -> [1, C, 1, W].
        Value transposeInit = tensor::EmptyOp::create(
            rewriter, loc, logicalFragmentType.getShape(), logicalFragmentType.getElementType());
        Value logicalFragment =
            linalg::TransposeOp::create(rewriter, loc, expanded, transposeInit, SmallVector {0, 3, 1, 2})
                .getResult()[0];
        // Insert at logical row `lane`: [0, 0, lane, 0] / [1, C, 1, W].
        SmallVector logicalOffsets {
            rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), args.lane, rewriter.getIndexAttr(0)};
        SmallVector logicalSizes {rewriter.getIndexAttr(1),
                                  rewriter.getIndexAttr(packedChannels),
                                  rewriter.getIndexAttr(1),
                                  rewriter.getIndexAttr(packedWidth)};
        createParallelInsertSliceIntoBatchOutput(rewriter,
                                                 loc,
                                                 logicalFragment,
                                                 args.outputs.front(),
                                                 logicalOffsets,
                                                 logicalSizes,
                                                 getUnitStrides(rewriter, 4));
        return success();
      });
  if (failed(batchOp))
    return failure();
  return batchOp->getResult(0);
}

/// Lowers selected Spatial planning ops in the PIM entry function to low-level
/// Spatial IR.
///
/// Phases:
///  1. Verify logical Spatial graph invariants.
///  2. Walk the entry function's top-level block and lower Conv plans, Relu
///     plans, materialize_layout ops and blueprints. Row-strip Conv/Relu plans
///     and their row-strip blueprints are not replaced in place; their lowered
///     values are recorded in `rowStripValues` and the ops are queued in
///     `eraseAfterLowering`.
///  3. Erase the queued ops once they have no uses (iterated to a fixed point);
///     any that remain are reported as errors.
///  4. Lower helper ONNX ops (Gemm/Transpose patterns) emitted by plan lowering:
///     a partial conversion on the module, then a full conversion inside each
///     collected compute-like op.
///  5. Reject any remaining plan/ONNX ops, dump the module and re-verify.
struct LowerSpatialPlansPass final : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerSpatialPlansPass)

  StringRef getArgument() const override { return "lower-spatial-plans"; }
  StringRef getDescription() const override { return "Lower selected Spatial planning ops to low-level Spatial IR."; }

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    MLIRContext* ctx = moduleOp.getContext();

    auto entryFunc = getPimEntryFunc(moduleOp);
    if (failed(entryFunc)) {
      moduleOp.emitError("failed to locate the PIM entry function during LowerSpatialPlans");
      signalPassFailure();
      return;
    }
    func::FuncOp funcOp = *entryFunc;

    PatternRewriter rewriter(ctx);
    // Row-strip blueprint results -> their lowered physical values.
    llvm::DenseMap rowStripValues;
    // Row-strip plan / blueprint ops to erase once all their users are rewritten.
    llvm::SmallPtrSet eraseAfterLowering;

    // Runs the logical-graph verifier; on failure emits an error mentioning
    // `stage`, signals pass failure and returns false.
    auto verifyLogicalPhase = [&](StringRef stage) -> bool {
      if (succeeded(verifyLogicalSpatialGraphInvariants(funcOp)))
        return true;
      moduleOp.emitError() << "logical Spatial graph verification failed " << stage;
      signalPassFailure();
      return false;
    };

    if (!verifyLogicalPhase("at the start of LowerSpatialPlans"))
      return;

    // Early-increment iteration: the current op may be replaced while visiting it.
    for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
      // Conv plan ops (lowered through lowerSelectedConv2DPlan).
      if (auto planOp = dyn_cast(&op)) {
        FailureOr rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
        auto rowStripBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
          auto blueprint = dyn_cast(user);
          return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
        });

        if (rowStripBlueprint != planOp.getResult().getUsers().end()) {
          // Row-strip output: consume the physical row-strip input when there is one.
          rewriter.setInsertionPoint(planOp);
          FailureOr lowered = lowerSelectedConv2DPlan(
              planOp,
              succeeded(rowStripInput) ? std::optional {rowStripInput->physicalValue} : std::nullopt,
              /*emitRowStripLayout=*/true,
              rewriter);
          if (failed(lowered)) {
            planOp.emitOpError("failed to lower selected row-strip Spatial Conv plan");
            signalPassFailure();
            return;
          }
          auto blueprint = cast(*rowStripBlueprint);
          FailureOr rowStripValue = buildRowStripValue(blueprint, *lowered);
          if (failed(rowStripValue)) {
            signalPassFailure();
            return;
          }
          rowStripValues[blueprint.getResult()] = *rowStripValue;
          eraseAfterLowering.insert(planOp);
          eraseAfterLowering.insert(blueprint);
          continue;
        }

        // Dense output. NOTE(review): a row-strip input is not passed here; such a
        // graph presumably leaves the producing row-strip blueprint in use, which is
        // rejected by the elimination check after this loop.
        rewriter.setInsertionPoint(planOp);
        FailureOr lowered = lowerSelectedConv2DPlan(planOp, std::nullopt, /*emitRowStripLayout=*/false, rewriter);
        if (failed(lowered)) {
          planOp.emitOpError("failed to lower selected Spatial Conv plan");
          signalPassFailure();
          return;
        }
        rewriter.replaceOp(planOp, *lowered);
        continue;
      }

      // Relu plan ops.
      if (auto planOp = dyn_cast(&op)) {
        FailureOr rowStripInput = getRowStripValue(rowStripValues, planOp.getInput());
        if (succeeded(rowStripInput)) {
          // A row-strip input requires a row-strip blueprint on the result.
          auto outputBlueprint = llvm::find_if(planOp.getResult().getUsers(), [](Operation* user) {
            auto blueprint = dyn_cast(user);
            return blueprint && blueprint.getPhysicalLayout() == kRowStripLayout;
          });
          if (outputBlueprint == planOp.getResult().getUsers().end()) {
            planOp.emitOpError("row-strip Relu plan requires a row-strip blueprint result");
            signalPassFailure();
            return;
          }
          rewriter.setInsertionPoint(planOp);
          FailureOr lowered = lowerRowStripRelu(*rowStripInput, planOp, rewriter);
          if (failed(lowered)) {
            planOp.emitOpError("failed to lower selected row-strip Spatial Relu plan");
            signalPassFailure();
            return;
          }
          auto blueprint = cast(*outputBlueprint);
          FailureOr output = buildRowStripValue(blueprint, *lowered);
          if (failed(output)) {
            signalPassFailure();
            return;
          }
          rowStripValues[blueprint.getResult()] = *output;
          eraseAfterLowering.insert(planOp);
          eraseAfterLowering.insert(blueprint);
          continue;
        }

        // Dense Relu: wrap a single ReLU in a Spatial compute op.
        rewriter.setInsertionPoint(planOp);
        auto computeOp = createSpatCompute<1>(
            rewriter, planOp.getLoc(), planOp.getOutput().getType(), {}, planOp.getInput(), [&](Value x) {
              auto relu = spatial::SpatReluOp::create(rewriter, planOp.getLoc(), planOp.getOutput().getType(), x);
              spatial::SpatYieldOp::create(rewriter, planOp.getLoc(), relu.getResult());
            });
        rewriter.replaceOp(planOp, computeOp.getResults());
        continue;
      }

      // materialize_layout ops: dense->dense is a no-op; row-strip->dense is
      // expanded via materializeRowStripToDense; everything else is rejected.
      if (auto materializeOp = dyn_cast(&op)) {
        if (materializeOp.getSourcePhysicalLayout() == kDenseLayout
            && materializeOp.getTargetPhysicalLayout() == kDenseLayout) {
          rewriter.replaceOp(materializeOp, materializeOp.getInput());
          continue;
        }
        if (materializeOp.getSourcePhysicalLayout() != kRowStripLayout
            || materializeOp.getTargetPhysicalLayout() != kDenseLayout) {
          materializeOp.emitOpError("non-dense materialize_layout lowering is not supported yet");
          signalPassFailure();
          return;
        }
        FailureOr rowStripValue = getRowStripValue(rowStripValues, materializeOp.getInput());
        if (failed(rowStripValue)) {
          materializeOp.emitOpError("expected a row-strip blueprint input during row-strip materialization");
          signalPassFailure();
          return;
        }
        rewriter.setInsertionPoint(materializeOp);
        FailureOr dense = materializeRowStripToDense(*rowStripValue, materializeOp.getLoc(), rewriter);
        if (failed(dense)) {
          materializeOp.emitOpError("failed to materialize selected row-strip layout back to dense NCHW");
          signalPassFailure();
          return;
        }
        rewriter.replaceOp(materializeOp, *dense);
        continue;
      }

      // Blueprint ops: dense ones forward their input; row-strip ones must have
      // been claimed by a row-strip plan lowering above.
      if (auto blueprintOp = dyn_cast(&op)) {
        if (blueprintOp.getPhysicalLayout() == kDenseLayout) {
          rewriter.replaceOp(blueprintOp, blueprintOp.getInput());
          continue;
        }
        if (blueprintOp.getPhysicalLayout() != kRowStripLayout) {
          blueprintOp.emitOpError("non-dense blueprint lowering is not supported yet");
          signalPassFailure();
          return;
        }
        if (!eraseAfterLowering.contains(blueprintOp)) {
          blueprintOp.emitOpError("unhandled row-strip blueprint remained during LowerSpatialPlans");
          signalPassFailure();
          return;
        }
      }
    }

    // Erase queued ops that no longer have uses; repeat because erasing a
    // blueprint can free the plan op that produces its input.
    bool erasedAny = true;
    while (erasedAny) {
      erasedAny = false;
      for (Operation& op : llvm::make_early_inc_range(funcOp.getBody().front())) {
        if (!eraseAfterLowering.contains(&op))
          continue;
        if (!op.use_empty())
          continue;
        eraseAfterLowering.erase(&op);
        rewriter.eraseOp(&op);
        erasedAny = true;
      }
    }
    if (!eraseAfterLowering.empty()) {
      for (Operation& op : funcOp.getBody().front())
        if (eraseAfterLowering.contains(&op))
          op.emitOpError("selected row-strip planning op could not be fully eliminated during LowerSpatialPlans");
      signalPassFailure();
      return;
    }

    // Lower helper ONNX ops emitted by plan lowering outside the ops marked
    // recursively legal.
    ConversionTarget helperTarget(*ctx);
    helperTarget.addLegalDialect();
    helperTarget.addLegalOp();
    helperTarget.addIllegalOp();
    helperTarget.markOpRecursivelyLegal();
    RewritePatternSet helperPatterns(ctx);
    populateGemmPatterns(helperPatterns, ctx);
    populateTransposePatterns(helperPatterns, ctx);
    if (failed(applyPartialConversion(moduleOp, helperTarget, std::move(helperPatterns)))) {
      moduleOp.emitError("failed to lower helper ONNX ops emitted by selected Spatial plan lowering");
      signalPassFailure();
      return;
    }

    // Then fully convert the helper ops nested inside each compute-like op.
    FrozenRewritePatternSet nestedHelperPatterns([&] {
      RewritePatternSet patterns(ctx);
      populateGemmPatterns(patterns, ctx);
      populateTransposePatterns(patterns, ctx);
      return patterns;
    }());
    ConversionTarget nestedHelperTarget(*ctx);
    nestedHelperTarget.addLegalDialect();
    nestedHelperTarget.addIllegalOp();

    // Collect first so the conversions below do not mutate IR under an active walk.
    SmallVector computeLikeOps;
    funcOp.walk([&](Operation* op) {
      if (isa(op))
        computeLikeOps.push_back(op);
    });
    for (Operation* op : computeLikeOps) {
      if (failed(applyFullConversion(op, nestedHelperTarget, nestedHelperPatterns))) {
        op->emitOpError("failed to lower nested helper ONNX ops emitted by selected Spatial plan lowering");
        signalPassFailure();
        return;
      }
    }

    if (!verifyLogicalPhase("after nested helper conversions"))
      return;

    // No planning ops or ONNX-dialect ops may survive this pass. The dialect
    // namespace is taken from the op name so unregistered ops (which have no
    // Dialect object) cannot cause a null dereference.
    bool hasIllegalOps = false;
    moduleOp.walk([&](Operation* op) {
      if (isa(op))
        return;
      if (isa(op) || op->getName().getDialectNamespace() == "onnx") {
        op->emitOpError("operation must not remain after LowerSpatialPlans");
        hasIllegalOps = true;
      }
    });
    if (hasIllegalOps) {
      // The pass has already failed; skip the dump and the final verification
      // so no redundant diagnostics are emitted.
      signalPassFailure();
      return;
    }

    dumpModule(moduleOp, "spatial1_premerge");
    if (!verifyLogicalPhase("at the end of LowerSpatialPlans"))
      return;
  }
};

} // namespace

/// Creates the "lower-spatial-plans" pass.
std::unique_ptr createLowerSpatialPlansPass() { return std::make_unique(); }

} // namespace onnx_mlir