6e1de865bb
better validation scripts output big refactors
47 lines
1.1 KiB
C++
47 lines
1.1 KiB
C++
#pragma once
|
|
|
|
#include "mlir/IR/Value.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <cstddef>
|
|
#include <map>
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/**
|
|
* @brief A helper struct to store a group of weights.
|
|
*
|
|
*/
|
|
struct TaggedWeights {
|
|
long inputTile;
|
|
long outputTile;
|
|
size_t startingCrossbarIndex;
|
|
llvm::SmallVector<mlir::Value> weights;
|
|
};
|
|
|
|
/**
|
|
* @brief A helper class to subdivide weights into groups.
|
|
*
|
|
* Weights are stored as a map of maps of SmallVectors. The outer map is indexed
|
|
* by input tile, the inner map is indexed by output tile, and the SmallVector
|
|
* contains the weights for the filter. This class allows us to extract groups
|
|
* of weights from the map until we've extracted a certain number of elements,
|
|
* namely as many as we need to fill a compute unit.
|
|
*/
|
|
class WeightSubdivider {
|
|
private:
|
|
std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights;
|
|
size_t crossbarsUsed = 0;
|
|
|
|
TaggedWeights popGroup(size_t amount);
|
|
|
|
public:
|
|
WeightSubdivider(std::map<long, std::map<long, llvm::SmallVector<mlir::Value>>> weights);
|
|
|
|
bool isEmpty() const;
|
|
llvm::SmallVector<TaggedWeights> popGroups(size_t n);
|
|
};
|
|
|
|
} // namespace onnx_mlir
|