#pragma once #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { using ResNum = unsigned int; using ComputeAndResNum = std::pair; struct SpatialReducerChange { Operation* fromOp; unsigned int fromOpResNum; Operation* toOp; unsigned int toOpOperandNum; }; using OpAndResNum = std::pair; class SpatialReducer { public: SpatialReducer(ConversionPatternRewriter& rewriter) : rewriter(rewriter) {} OpAndResNum applyReducePattern(SmallVector& computeOpsAndResNum, std::function reduce, std::function preprocess, std::function postprocess); OpAndResNum applyAddMapReduction(SmallVector& computeOps, ConversionPatternRewriter& rewriter, Value biasTile, MapOperations mapOp); void finalizeReduceUpdates(); ~SpatialReducer() { if (!reducesFinalized) finalizeReduceUpdates(); } Value createImgConcatOp(llvm::SmallVector>>& outputTiles, Location& loc, Type outputType); Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); private: [[nodiscard("computeOp result number gets updated")]] ResNum applyResultProcessing(ComputeAndResNum computeOpAndResNum, std::function processFun, ConversionPatternRewriter& rewriter); /** * @brief Update the results of a ComputeOp. * * This function updates the results of a ComputeOp by taking a look at the operands of its yieldOp. * If the ComputeOp was replaced, it updates `opToReplacedCompute` with the replaced ComputeOp. * * @param computeOp The ComputeOp to update the results of. */ void updateResultsOfCompute(Operation* computeOp); ConversionPatternRewriter& rewriter; bool reducesFinalized = false; // List of changes to be applied after the reduction is finalized SmallVector reducerChanges; // List of computeOps that need to be replaced with new results SmallVector computeOpNeedingResUpdate; std::unordered_map opToReplacedCompute; static llvm::SmallPtrSet oldComputeOpsReplaced; }; } // namespace onnx_mlir