Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp
T
NiccoloN 6e1de865bb add constant folding and verification pass for pim host operations
better validation scripts output
big refactors
2026-03-20 12:08:12 +01:00

47 lines
1.1 KiB
C++

#pragma once
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <map>
namespace onnx_mlir {
/**
 * @brief A helper struct to store a group of weights together with the tile
 * coordinates they were taken from.
 *
 * All scalar members carry default initializers, so a default-constructed
 * TaggedWeights never holds indeterminate values. Aggregate initialization
 * (`TaggedWeights{in, out, start, weights}`) remains valid.
 */
struct TaggedWeights {
  /// Input-tile key the weights were taken from (outer key of the
  /// WeightSubdivider map).
  long inputTile = 0;
  /// Output-tile key the weights were taken from (inner key of the
  /// WeightSubdivider map).
  long outputTile = 0;
  /// Presumably the index of the first crossbar assigned to this group;
  /// TODO confirm against WeightSubdivider::popGroup.
  size_t startingCrossbarIndex = 0;
  /// The weight values in this group.
  llvm::SmallVector<mlir::Value> weights;
};
/**
 * @brief A helper class to subdivide weights into groups.
 *
 * Weights are stored as a map of maps of SmallVectors:
 *   - the outer map is indexed by input tile;
 *   - the inner map is indexed by output tile;
 *   - the SmallVector holds the weights for that (input tile, output tile) pair.
 *
 * This class allows us to extract groups of weights from the map until a
 * requested amount has been extracted, namely as many as we need to fill a
 * compute unit.
 *
 * NOTE(review): the unit of `amount` / `n` (weights vs. crossbars vs. groups)
 * is fixed by the implementation file. Confirm it there before relying on it.
 */
class WeightSubdivider {
private:
  /// Weights still available for extraction: input tile -> output tile ->
  /// weights. Presumably shrinks as groups are popped (TODO confirm).
  std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;

  /// Counter starting at 0. Presumably the number of crossbars handed out so
  /// far, used to fill TaggedWeights::startingCrossbarIndex (TODO confirm).
  size_t crossbarsUsed = 0;

  /// Extracts a single group of weights, bounded by `amount`.
  TaggedWeights popGroup(size_t amount);

public:
  /// Takes the weight map by value; callers can std::move it in to avoid a copy.
  /// NOTE(review): not `explicit`, so a weight map converts implicitly to a
  /// WeightSubdivider. Kept as-is for source compatibility with callers.
  WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);

  /// Presumably true once no weights remain to be popped. Ignoring the result
  /// is always a mistake, hence [[nodiscard]].
  [[nodiscard]] bool isEmpty() const;

  /// Pops groups of weights, bounded by `n`. The returned groups are removed
  /// from this subdivider, so discarding them would lose weights; hence
  /// [[nodiscard]].
  [[nodiscard]] llvm::SmallVector<TaggedWeights> popGroups(size_t n);
};
} // namespace onnx_mlir