Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp
NiccoloN 810e5e75f9 add .clang-format
reformat all src
2026-02-26 19:16:42 +01:00

54 lines
1.3 KiB
C++

#include <cassert>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
namespace onnx_mlir {

/// Takes ownership of the input-tile -> output-tile -> weights table.
///
/// Entries that carry no weights are dropped up front. An input tile whose
/// inner map is empty would otherwise make popGroup() dereference `begin()` of
/// an empty map (undefined behavior) while isEmpty() still reports false, and
/// an output tile with an empty weight list would only ever yield an empty
/// group. After pruning, every inner map and every weight list is non-empty;
/// popGroup() preserves this by erasing entries as soon as they are exhausted.
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
    : weights(std::move(weights)) {
  // The parameter shadows the member (and has been moved-from by now), so the
  // member must be spelled `this->weights` inside the body.
  for (auto inputIt = this->weights.begin(); inputIt != this->weights.end();) {
    auto& outputTiles = inputIt->second;
    for (auto outputIt = outputTiles.begin(); outputIt != outputTiles.end();) {
      if (outputIt->second.empty())
        outputIt = outputTiles.erase(outputIt);
      else
        ++outputIt;
    }
    if (outputTiles.empty())
      inputIt = this->weights.erase(inputIt);
    else
      ++inputIt;
  }
}

/// Returns true once every weight has been handed out by popGroup().
bool WeightSubdivider::isEmpty() const { return weights.empty(); }

/// Removes and returns up to `amount` weights from the first
/// (input tile, output tile) entry in map order.
///
/// The returned group holds, in order: the input tile, the output tile, the
/// offset of the group's first crossbar (the value of `crossbarsUsed` before
/// this call, which counts one per weight value since the last popGroups()
/// reset), and the extracted weights. An output-tile entry that is fully
/// consumed is erased, and so is its input tile once it has no output tiles
/// left.
///
/// Precondition: !isEmpty().
TaggedWeights WeightSubdivider::popGroup(size_t amount) {
  assert(!weights.empty() && "No weights to extract.");
  auto inputIt = weights.begin();
  auto& outputTiles = inputIt->second;
  // Non-empty by the invariant established in the constructor.
  assert(!outputTiles.empty() && "Input tile without output tiles.");
  auto outputIt = outputTiles.begin();
  SmallVector<Value>& values = outputIt->second;
  const long inputTile = inputIt->first;
  const long outputTile = outputIt->first;

  const size_t n = std::min(amount, values.size());
  crossbarsUsed += n;
  SmallVector<Value> result(values.begin(), values.begin() + n);

  if (n < values.size()) {
    values.erase(values.begin(), values.begin() + n);
  }
  else {
    // This (input, output) pair is exhausted. Erase through the iterators we
    // already hold instead of looking the keys up again. `values` dangles
    // after this, but it is not used below.
    outputTiles.erase(outputIt);
    if (outputTiles.empty())
      weights.erase(inputIt);
  }
  // `crossbarsUsed - n` is the offset of this group's first crossbar.
  return {inputTile, outputTile, crossbarsUsed - n, std::move(result)};
}

/// Pops groups until `n` weights have been collected or the table runs out.
///
/// The crossbar counter is reset first, so the offsets in the returned groups
/// are relative to this batch. A group never spans two (input, output) tile
/// pairs, so a single call may return several groups whose sizes sum to at
/// most `n`.
SmallVector<TaggedWeights> WeightSubdivider::popGroups(size_t n) {
  crossbarsUsed = 0;
  SmallVector<TaggedWeights> result;
  size_t remaining = n;
  while (remaining > 0 && !isEmpty()) {
    TaggedWeights group = popGroup(remaining);
    // Read the size before the group is moved into the result.
    remaining -= group.weights.size();
    result.push_back(std::move(group));
  }
  return result;
}

} // namespace onnx_mlir