#include #include #include #include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp" using onnx_mlir::LocalAllocInterval; using onnx_mlir::planPhysicalSlots; namespace { LocalAllocInterval makeInterval(size_t id, size_t size, uint64_t start, uint64_t end) { LocalAllocInterval interval; interval.id = id; interval.size = size; interval.start = start; interval.end = end; return interval; } void assertSingleSlotCase(LocalAllocInterval a, LocalAllocInterval b, size_t expectedSlotSize) { llvm::SmallVector intervals = {a, b}; auto slots = planPhysicalSlots(intervals); assert(slots.size() == 1); assert(slots.front().requiredSize == expectedSlotSize); assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId); } int testSameSizeNonOverlap() { std::cout << "testSameSizeNonOverlap:" << std::endl; assertSingleSlotCase(makeInterval(0, 64, 0, 10), makeInterval(1, 64, 11, 20), 64); return 0; } int testLargerFirst() { std::cout << "testLargerFirst:" << std::endl; assertSingleSlotCase(makeInterval(0, 100, 0, 10), makeInterval(1, 40, 11, 20), 100); return 0; } int testSmallerFirst() { std::cout << "testSmallerFirst:" << std::endl; assertSingleSlotCase(makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), 100); return 0; } int testOverlapNeedsTwoSlots() { std::cout << "testOverlapNeedsTwoSlots:" << std::endl; llvm::SmallVector intervals = { makeInterval(0, 100, 0, 20), makeInterval(1, 40, 10, 30)}; auto slots = planPhysicalSlots(intervals); assert(slots.size() == 2); assert(intervals[0].physicalSlotId != intervals[1].physicalSlotId); return 0; } int testReuseChain() { std::cout << "testReuseChain:" << std::endl; llvm::SmallVector intervals = { makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), makeInterval(2, 20, 21, 30)}; auto slots = planPhysicalSlots(intervals); assert(slots.size() == 1); assert(slots.front().requiredSize == 100); assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId); assert(intervals[1].physicalSlotId == intervals[2].physicalSlotId); return 0; } } // namespace int main(int argc, char *argv[]) { (void) argc; (void) argv; int failures = 0; failures += testSameSizeNonOverlap(); failures += testLargerFirst(); failures += testSmallerFirst(); failures += testOverlapNeedsTwoSlots(); failures += testReuseChain(); if (failures != 0) { std::cerr << failures << " test failures\n"; return EXIT_FAILURE; } return EXIT_SUCCESS; }