#pragma once #include "mlir/IR/Value.h" #include "llvm/ADT/SmallVector.h" #include #include namespace onnx_mlir { /** * @brief A helper struct to store a group of weights. * */ struct TaggedWeights { long inputTile; long outputTile; size_t startingCrossbarIndex; llvm::SmallVector weights; }; /** * @brief A helper class to subdivide weights into groups. * * Weights are stored as a map of maps of SmallVectors. The outer map is indexed * by input tile, the inner map is indexed by output tile, and the SmallVector * contains the weights for the filter. This class allows us to extract groups * of weights from the map until we've extracted a certain number of elements, * namely as many as we need to fill a compute unit. */ class WeightSubdivider { private: std::map>> weights; size_t crossbarsUsed = 0; TaggedWeights popGroup(size_t amount); public: WeightSubdivider(std::map>> weights); bool isEmpty() const; llvm::SmallVector popGroups(size_t n); }; } // namespace onnx_mlir