Files
Raptor/test/PIM/PimMemoryLivenessPlannerTest.cpp
T
2026-06-03 18:15:30 +02:00

87 lines
2.6 KiB
C++

#include <cassert>
#include <cstdlib>
#include <iostream>
#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp"
using onnx_mlir::LocalAllocInterval;
using onnx_mlir::planPhysicalSlots;
namespace {
LocalAllocInterval makeInterval(size_t id, size_t size, uint64_t start, uint64_t end) {
LocalAllocInterval interval;
interval.id = id;
interval.size = size;
interval.start = start;
interval.end = end;
return interval;
}
void assertSingleSlotCase(LocalAllocInterval a, LocalAllocInterval b, size_t expectedSlotSize) {
llvm::SmallVector<LocalAllocInterval, 4> intervals = {a, b};
auto slots = planPhysicalSlots(intervals);
assert(slots.size() == 1);
assert(slots.front().requiredSize == expectedSlotSize);
assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId);
}
int testSameSizeNonOverlap() {
std::cout << "testSameSizeNonOverlap:" << std::endl;
assertSingleSlotCase(makeInterval(0, 64, 0, 10), makeInterval(1, 64, 11, 20), 64);
return 0;
}
int testLargerFirst() {
std::cout << "testLargerFirst:" << std::endl;
assertSingleSlotCase(makeInterval(0, 100, 0, 10), makeInterval(1, 40, 11, 20), 100);
return 0;
}
int testSmallerFirst() {
std::cout << "testSmallerFirst:" << std::endl;
assertSingleSlotCase(makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), 100);
return 0;
}
int testOverlapNeedsTwoSlots() {
std::cout << "testOverlapNeedsTwoSlots:" << std::endl;
llvm::SmallVector<LocalAllocInterval, 4> intervals = {
makeInterval(0, 100, 0, 20), makeInterval(1, 40, 10, 30)};
auto slots = planPhysicalSlots(intervals);
assert(slots.size() == 2);
assert(intervals[0].physicalSlotId != intervals[1].physicalSlotId);
return 0;
}
int testReuseChain() {
std::cout << "testReuseChain:" << std::endl;
llvm::SmallVector<LocalAllocInterval, 4> intervals = {
makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), makeInterval(2, 20, 21, 30)};
auto slots = planPhysicalSlots(intervals);
assert(slots.size() == 1);
assert(slots.front().requiredSize == 100);
assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId);
assert(intervals[1].physicalSlotId == intervals[2].physicalSlotId);
return 0;
}
} // namespace
/// Test driver: runs every planner test in a fixed order and sums the integer
/// each one returns into a failure count.
/// @return EXIT_SUCCESS when the summed count is 0; otherwise prints the count
///         to stderr and returns EXIT_FAILURE.
int main(int argc, char *argv[]) {
  // The driver takes no command-line options.
  static_cast<void>(argc);
  static_cast<void>(argv);

  using TestFn = int (*)();
  // Execution order matches the order listed here.
  const TestFn tests[] = {testSameSizeNonOverlap, testLargerFirst, testSmallerFirst,
                          testOverlapNeedsTwoSlots, testReuseChain};

  int failures = 0;
  for (const TestFn test : tests)
    failures += test();

  if (failures == 0)
    return EXIT_SUCCESS;

  std::cerr << failures << " test failures\n";
  return EXIT_FAILURE;
}