Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp
T
NiccoloN 11916a2595 refactor Pim constant folding pass
share contiguous address resolution in PimCommon
group patterns in subdir for each pass with pattern files
2026-03-23 15:36:58 +01:00

89 lines
3.0 KiB
C++

#pragma once
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include <functional>
#include <unordered_map>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
/// Result index of an operation.
/// NOTE(review): presumably the index passed to `Operation::getResult` — confirm at use sites.
using ResNum = unsigned int;

/// A SpatWeightedCompute op paired with one of its result numbers.
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;

/// One deferred change recorded by SpatialReducer and applied later
/// (see SpatialReducer::reducerChanges / finalizeReduceUpdates()).
/// NOTE(review): judging by the field names this rewires result
/// `fromOpResNum` of `fromOp` into operand `toOpOperandNum` of `toOp` — verify
/// against finalizeReduceUpdates() in the .cpp.
///
/// All fields carry default member initializers so a default-constructed
/// change never holds indeterminate pointers. Aggregate initialization
/// (`SpatialReducerChange{a, b, c, d}`) still works in C++14 and later.
struct SpatialReducerChange {
  mlir::Operation* fromOp = nullptr;
  unsigned int fromOpResNum = 0;
  mlir::Operation* toOp = nullptr;
  unsigned int toOpOperandNum = 0;
};

/// An arbitrary operation paired with one of its result numbers.
using OpAndResNum = std::pair<mlir::Operation*, ResNum>;
/// Builds reductions over the results of SpatWeightedCompute ops and records
/// deferred changes (`reducerChanges`) that are applied by
/// finalizeReduceUpdates(). If that has not happened by the time the reducer
/// is destroyed, the destructor calls it.
///
/// The reducer keeps a reference to the rewriter passed to the constructor;
/// that rewriter must outlive the reducer.
class SpatialReducer {
public:
  /// @param rewriter Rewriter used to emit IR. It is stored by reference.
  /// `explicit` prevents an accidental implicit conversion from a rewriter
  /// into a temporary SpatialReducer. Such a temporary would finalize in its
  /// destructor.
  explicit SpatialReducer(mlir::ConversionPatternRewriter& rewriter)
  : rewriter(rewriter) {}

  // Non-copyable: a copy would duplicate the pending `reducerChanges` together
  // with an unset `reducesFinalized`. Both destructors would then call
  // finalizeReduceUpdates() over the same recorded changes.
  SpatialReducer(const SpatialReducer&) = delete;
  SpatialReducer& operator=(const SpatialReducer&) = delete;

  /// Reduces the given compute results into a single result.
  /// NOTE(review): presumably applies `preprocess` to each input, combines
  /// values pairwise with `reduce`, then applies `postprocess` to the final
  /// value — confirm against the .cpp.
  /// @return The op and result number that hold the reduced value.
  OpAndResNum applyReducePattern(llvm::SmallVector<ComputeAndResNum>& computeOpsAndResNum,
                                 std::function<mlir::Value(const mlir::Value&, const mlir::Value&)> reduce,
                                 std::function<mlir::Value(const mlir::Value&)> preprocess,
                                 std::function<mlir::Value(const mlir::Value&)> postprocess);

  /// Add-based reduction over `computeOps`, using `biasTile` and `mapOp`.
  /// NOTE(review): the `rewriter` parameter shadows the member of the same
  /// name; check which one the definition actually uses.
  OpAndResNum applyAddMapReduction(llvm::SmallVector<ComputeAndResNum>& computeOps,
                                   mlir::ConversionPatternRewriter& rewriter,
                                   mlir::Value biasTile,
                                   MapOperations mapOp);

  /// Applies the changes collected in `reducerChanges`.
  /// The destructor calls this automatically while `reducesFinalized` is false.
  void finalizeReduceUpdates();

  ~SpatialReducer() {
    if (!reducesFinalized)
      finalizeReduceUpdates();
  }

  /// Builds a single Value of type `outputType` at `loc` from a 3-level
  /// nested collection of (op, result number) tiles.
  /// NOTE(review): the meaning of each nesting level is defined in the .cpp.
  mlir::Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
                                mlir::Location& loc,
                                mlir::Type outputType);

  /// Returns the mlir::Value designated by an (op, result number) pair.
  /// NOTE(review): presumably takes compute-op replacements
  /// (`opToReplacedCompute`) into account — confirm against the .cpp.
  mlir::Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);

private:
  /// Applies `processFun` to the given result of a compute op.
  /// @return The updated result number for the compute op. Discarding it is
  ///         a bug, hence [[nodiscard]].
  /// NOTE(review): the `rewriter` parameter shadows the member.
  [[nodiscard("computeOp result number gets updated")]] ResNum
  applyResultProcessing(ComputeAndResNum computeOpAndResNum,
                        std::function<mlir::Value(const mlir::Value&)> processFun,
                        mlir::ConversionPatternRewriter& rewriter);

  /**
   * @brief Update the results of a ComputeOp.
   *
   * This function updates the results of a ComputeOp by looking at the
   * operands of its yieldOp.
   * If the ComputeOp was replaced, it updates `opToReplacedCompute` with the
   * replaced ComputeOp.
   *
   * @param computeOp The ComputeOp to update the results of.
   */
  void updateResultsOfCompute(mlir::Operation* computeOp);

  // Rewriter supplied at construction (non-owning reference).
  mlir::ConversionPatternRewriter& rewriter;
  // Set once the pending changes have been applied; checked by the destructor.
  bool reducesFinalized = false;
  // List of changes to be applied after the reduction is finalized.
  llvm::SmallVector<SpatialReducerChange, 4> reducerChanges;
  // List of computeOps that need to be replaced with new results.
  llvm::SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
  // Maps an original op to the compute op that replaced it.
  std::unordered_map<mlir::Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
  // NOTE(review): `static` makes this set shared by every SpatialReducer
  // instance for the lifetime of the process. It stores raw Operation*
  // pointers that may dangle once those ops are erased. It also requires an
  // out-of-line definition in a .cpp. Consider making it a per-instance member.
  static llvm::SmallPtrSet<mlir::Operation*, 16> oldComputeOpsReplaced;
};
} // namespace onnx_mlir