87 lines
2.6 KiB
C++
87 lines
2.6 KiB
C++
#include <cassert>
|
|
#include <cstdlib>
|
|
#include <iostream>
|
|
|
|
#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp"
|
|
|
|
using onnx_mlir::LocalAllocInterval;
|
|
using onnx_mlir::planPhysicalSlots;
|
|
|
|
namespace {
|
|
|
|
/// Builds a LocalAllocInterval with the given id, size and [start, end]
/// lifetime bounds, for feeding into planPhysicalSlots().
///
/// The struct is value-initialized (`{}`) rather than default-initialized so
/// that any field not assigned below (e.g. the physicalSlotId that the planner
/// fills in) starts from a well-defined value instead of being indeterminate.
/// NOTE(review): if LocalAllocInterval provides default member initializers or
/// a default constructor, `{}` uses them, so behavior is unchanged in that case.
LocalAllocInterval makeInterval(size_t id, size_t size, uint64_t start, uint64_t end) {
  LocalAllocInterval interval{};
  interval.id = id;
  interval.size = size;
  interval.start = start;
  interval.end = end;
  return interval;
}
|
|
|
|
/// Runs planPhysicalSlots() on the two intervals `a` and `b` and checks that:
///   * exactly one physical slot is produced,
///   * that slot's requiredSize equals `expectedSlotSize`,
///   * both intervals were assigned the same physicalSlotId.
/// On any mismatch a message is printed to std::cerr and the process exits
/// with EXIT_FAILURE.
///
/// Explicit checks are used instead of assert(): assert() is compiled out
/// when NDEBUG is defined (common in Release builds), which would make this
/// helper a silent no-op. It would also let slots.front() run on an empty
/// result, which is undefined behavior.
void assertSingleSlotCase(LocalAllocInterval a, LocalAllocInterval b, size_t expectedSlotSize) {
  llvm::SmallVector<LocalAllocInterval, 4> intervals = {a, b};
  auto slots = planPhysicalSlots(intervals);

  if (slots.size() != 1) {
    std::cerr << "  FAILED: expected 1 physical slot, got " << slots.size() << "\n";
    std::exit(EXIT_FAILURE);
  }
  if (slots.front().requiredSize != expectedSlotSize) {
    std::cerr << "  FAILED: expected slot size " << expectedSlotSize << ", got "
              << slots.front().requiredSize << "\n";
    std::exit(EXIT_FAILURE);
  }
  if (intervals[0].physicalSlotId != intervals[1].physicalSlotId) {
    std::cerr << "  FAILED: non-overlapping intervals were not assigned the same slot\n";
    std::exit(EXIT_FAILURE);
  }
}
|
|
|
|
int testSameSizeNonOverlap() {
|
|
std::cout << "testSameSizeNonOverlap:" << std::endl;
|
|
assertSingleSlotCase(makeInterval(0, 64, 0, 10), makeInterval(1, 64, 11, 20), 64);
|
|
return 0;
|
|
}
|
|
|
|
/// A 100-byte buffer (lifetime 0..10) followed by a 40-byte buffer
/// (lifetime 11..20): the test expects one shared slot sized for the larger,
/// earlier buffer.
int testLargerFirst() {
  std::cout << "testLargerFirst:" << std::endl;
  const size_t largeSize = 100;
  const size_t smallSize = 40;
  assertSingleSlotCase(makeInterval(0, largeSize, 0, 10), makeInterval(1, smallSize, 11, 20),
                       /*expectedSlotSize=*/largeSize);
  return 0;
}
|
|
|
|
int testSmallerFirst() {
|
|
std::cout << "testSmallerFirst:" << std::endl;
|
|
assertSingleSlotCase(makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), 100);
|
|
return 0;
|
|
}
|
|
|
|
int testOverlapNeedsTwoSlots() {
|
|
std::cout << "testOverlapNeedsTwoSlots:" << std::endl;
|
|
llvm::SmallVector<LocalAllocInterval, 4> intervals = {
|
|
makeInterval(0, 100, 0, 20), makeInterval(1, 40, 10, 30)};
|
|
auto slots = planPhysicalSlots(intervals);
|
|
assert(slots.size() == 2);
|
|
assert(intervals[0].physicalSlotId != intervals[1].physicalSlotId);
|
|
return 0;
|
|
}
|
|
|
|
int testReuseChain() {
|
|
std::cout << "testReuseChain:" << std::endl;
|
|
llvm::SmallVector<LocalAllocInterval, 4> intervals = {
|
|
makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), makeInterval(2, 20, 21, 30)};
|
|
auto slots = planPhysicalSlots(intervals);
|
|
assert(slots.size() == 1);
|
|
assert(slots.front().requiredSize == 100);
|
|
assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId);
|
|
assert(intervals[1].physicalSlotId == intervals[2].physicalSlotId);
|
|
return 0;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Test driver: runs every slot-planning test in declaration order, sums the
/// failure counts they return, and maps the total to the process exit code.
int main(int argc, char *argv[]) {
  // Command-line arguments are accepted but not used.
  static_cast<void>(argc);
  static_cast<void>(argv);

  using TestFn = int (*)();
  const TestFn tests[] = {testSameSizeNonOverlap, testLargerFirst, testSmallerFirst,
                          testOverlapNeedsTwoSlots, testReuseChain};

  int failures = 0;
  for (const TestFn test : tests)
    failures += test();

  if (failures == 0)
    return EXIT_SUCCESS;

  std::cerr << failures << " test failures\n";
  return EXIT_FAILURE;
}
|