54 lines
1.3 KiB
C++
54 lines
1.3 KiB
C++
#include <cassert>
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
|
|
: weights(std::move(weights)) {}
|
|
|
|
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
|
|
|
|
TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
|
assert(!weights.empty() && "No weights to extract.");
|
|
|
|
auto it = weights.begin();
|
|
SmallVector<Value>& values = it->second.begin()->second;
|
|
|
|
long inputTile = it->first;
|
|
long outputTile = it->second.begin()->first;
|
|
|
|
size_t n = std::min(amount, values.size());
|
|
crossbarsUsed += n;
|
|
|
|
SmallVector<Value> result;
|
|
result.assign(values.begin(), values.begin() + n);
|
|
|
|
if (n < values.size()) {
|
|
values.erase(values.begin(), values.begin() + n);
|
|
}
|
|
else {
|
|
it->second.erase(outputTile);
|
|
if (it->second.empty())
|
|
weights.erase(inputTile);
|
|
}
|
|
|
|
return {inputTile, outputTile, crossbarsUsed - n, result};
|
|
}
|
|
|
|
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
|
|
crossbarsUsed = 0;
|
|
SmallVector<TaggedWeights> result;
|
|
size_t remaining = n;
|
|
|
|
while (remaining > 0 && !weights.empty()) {
|
|
auto group = popGroup(remaining);
|
|
result.push_back(group);
|
|
remaining -= group.weights.size();
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
} // namespace onnx_mlir
|