85 lines
2.7 KiB
C++
85 lines
2.7 KiB
C++
#pragma once
|
|
|
|
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

#include <functional>
#include <unordered_map>
#include <utility>
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Index selecting one result of an operation.
using ResNum = unsigned;
|
|
|
|
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
|
|
|
|
struct SpatialReducerChange {
|
|
Operation* fromOp;
|
|
unsigned int fromOpResNum;
|
|
Operation* toOp;
|
|
unsigned int toOpOperandNum;
|
|
};
|
|
|
|
using OpAndResNum = std::pair<Operation*, ResNum>;
|
|
|
|
class SpatialReducer {
|
|
|
|
public:
|
|
SpatialReducer(ConversionPatternRewriter& rewriter)
|
|
: rewriter(rewriter) {}
|
|
|
|
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
|
std::function<Value(const Value&, const Value&)> reduce,
|
|
std::function<Value(const Value&)> preprocess,
|
|
std::function<Value(const Value&)> postprocess);
|
|
|
|
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
|
ConversionPatternRewriter& rewriter,
|
|
Value biasTile,
|
|
MapOperations mapOp);
|
|
|
|
void finalizeReduceUpdates();
|
|
|
|
~SpatialReducer() {
|
|
if (!reducesFinalized)
|
|
finalizeReduceUpdates();
|
|
}
|
|
|
|
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
|
Location& loc,
|
|
Type outputType);
|
|
|
|
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
|
|
|
|
private:
|
|
[[nodiscard("computeOp result number gets updated")]] ResNum
|
|
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
|
std::function<Value(const Value&)> processFun,
|
|
ConversionPatternRewriter& rewriter);
|
|
|
|
/**
|
|
* @brief Update the results of a ComputeOp.
|
|
*
|
|
* This function updates the results of a ComputeOp by taking a look at the
|
|
operands of its yieldOp.
|
|
* If the ComputeOp was replaced, it updates `opToReplacedCompute` with the
|
|
replaced ComputeOp.
|
|
*
|
|
* @param computeOp The ComputeOp to update the results of.
|
|
*/
|
|
void updateResultsOfCompute(Operation* computeOp);
|
|
|
|
ConversionPatternRewriter& rewriter;
|
|
bool reducesFinalized = false;
|
|
|
|
// List of changes to be applied after the reduction is finalized
|
|
SmallVector<SpatialReducerChange, 4> reducerChanges;
|
|
// List of computeOps that need to be replaced with new results
|
|
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
|
|
|
|
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
|
|
|
|
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
|
};
|
|
|
|
} // namespace onnx_mlir
|