diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index ef579ab..36796e4 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -12,6 +12,7 @@ #include "llvm/Support/Casting.h" #include +#include #include #include #include @@ -21,6 +22,7 @@ #include "ComputeGraph.hpp" #include "ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Support/TypeUtilities.hpp" @@ -35,9 +37,223 @@ uint64_t countComputeBodyOperationInstances(Region& body); namespace { -Cost getComputeBodyCost(Region& body) { - constexpr Cost kOperationCost = 100; - return checkedMultiply(static_cast(countComputeBodyOperationInstances(body)), kOperationCost); +struct PimsimSchedulerCostModel { + static constexpr Cost kDefaultBitwidth = 8; + static constexpr Cost kCorePeriodNs = 1; + static constexpr Cost kLocalMemoryWidthBytes = 64; + static constexpr Cost kLocalMemoryLatencyCycles = 1; + static constexpr Cost kNetworkBusWidthBytes = 8; + static constexpr Cost kNetworkBaseLatencyNs = 2; + static constexpr Cost kNetworkPerHopLatencyNs = 1; + static constexpr Cost kVectorWidth = 16; + static constexpr Cost kVectorLatencyCycles = 4; + static constexpr Cost kDacResolutionBits = 1; + static constexpr Cost kDacLatencyCycles = 1; + static constexpr Cost kDacCount = 128; + static constexpr Cost kXbarReadLatencyNs = 30; + static constexpr Cost kSampleHoldLatencyCycles = 1; + static constexpr Cost kAdcLatencyCycles = 10; + static constexpr Cost kAdcCount = 2; + static constexpr Cost kShiftAdderLatencyCycles = 1; + static constexpr Cost kOutputBufferLatencyCycles = 1; + static constexpr Cost kInputBufferLatencyCycles = 0; + static constexpr Cost kFallbackOperationCost = 1; + + static Cost ceilDiv(Cost numerator, Cost denominator) { + assert(denominator > 0 && "denominator must be positive"); + return (numerator + denominator - 1) / denominator; + } + + static std::optional getStaticElementCount(Type type) { + auto shaped = dyn_cast(type); + if (!shaped || !shaped.hasStaticShape()) + return std::nullopt; + return static_cast(shaped.getNumElements()); + } + + static Cost getBitwidthOrDefault(Type type) { + if (auto shaped = dyn_cast(type)) + type = shaped.getElementType(); + if (auto intType = dyn_cast(type)) + return intType.getWidth(); + if (auto floatType = dyn_cast(type)) + return floatType.getWidth(); + if (isa(type)) + return 64; + return kDefaultBitwidth; + } + + static Cost getComputeBitwidth(Type type) { + return std::min(getBitwidthOrDefault(type), kDefaultBitwidth); + } + + static Cost getByteSize(Type type, Cost fallbackBitwidth = kDefaultBitwidth) { + auto elementCount = getStaticElementCount(type); + if (!elementCount) + return kFallbackOperationCost; + Cost bitwidth = fallbackBitwidth; + if (bitwidth <= 0) + bitwidth = getBitwidthOrDefault(type); + if (bitwidth <= 0) + bitwidth = kDefaultBitwidth; + return ceilDiv(checkedMultiply(*elementCount, bitwidth), static_cast(8)); + } + + static Cost getVectorReadWriteCost(Cost readBytes, Cost writeBytes) { + Cost totalBytes = checkedAdd(readBytes, writeBytes); + return checkedMultiply(ceilDiv(totalBytes, kLocalMemoryWidthBytes), + checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs)); + } + + static Cost getVectorComputeCost(Cost elementCount) { + return checkedMultiply(ceilDiv(elementCount, kVectorWidth), + checkedMultiply(kVectorLatencyCycles, kCorePeriodNs)); + } + + static Cost getTensorMoveCost(Type type) { + return getVectorReadWriteCost(getByteSize(type), 0); + } + + static std::pair estimateMeshShape() { + Cost coreCount = static_cast(std::max(1, coresCount.getValue())); + Cost rows = static_cast(std::sqrt(static_cast(coreCount))); + if (rows == 0) + rows = 1; + while (rows > 1 && coreCount % rows != 0) + --rows; + Cost cols = ceilDiv(coreCount, rows); + return {rows, cols}; + } + + static Cost getAverageInterCoreLatencyNs() { + auto [rows, cols] = estimateMeshShape(); + auto averageAxisDistance = [](Cost size) -> Cost { + if (size <= 1) + return 0; + return checkedMultiply(size, size) - 1; + }; + Cost avgRow = averageAxisDistance(rows) / (static_cast(3) * rows); + Cost avgCol = averageAxisDistance(cols) / (static_cast(3) * cols); + return checkedAdd(kNetworkBaseLatencyNs, checkedMultiply(kNetworkPerHopLatencyNs, checkedAdd(avgRow, avgCol))); + } + + static Cost getInterCoreTransferCostFromBytes(Cost bytes) { + Cost localRead = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes), + checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs)); + Cost localWrite = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes), + checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs)); + Cost payloadFlits = ceilDiv(bytes, kNetworkBusWidthBytes); + Cost averageNoCLatency = getAverageInterCoreLatencyNs(); + Cost network = checkedMultiply(checkedAdd(static_cast(2), payloadFlits), averageNoCLatency); + return checkedAdd(checkedAdd(localRead, localWrite), network); + } + + static Cost getUnaryVectorCost(Type inputType, Type outputType, bool scalarOutput = false) { + auto maybeElements = getStaticElementCount(inputType); + if (!maybeElements) + return kFallbackOperationCost; + Cost inputBytes = getByteSize(inputType, getComputeBitwidth(inputType)); + Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast(8)) + : getByteSize(outputType, getComputeBitwidth(outputType)); + return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getVectorComputeCost(*maybeElements)); + } + + static Cost getBinaryVectorCost(Type lhsType, Type rhsType, Type outputType, bool scalarOutput = false) { + auto maybeElements = getStaticElementCount(lhsType); + if (!maybeElements) + return kFallbackOperationCost; + Cost readBytes = checkedAdd(getByteSize(lhsType, getComputeBitwidth(lhsType)), + getByteSize(rhsType, getComputeBitwidth(rhsType))); + Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast(8)) + : getByteSize(outputType, getComputeBitwidth(outputType)); + return checkedAdd(getVectorReadWriteCost(readBytes, outputBytes), getVectorComputeCost(*maybeElements)); + } + + static Cost getMatrixComputeLatency(Cost inputBitwidth) { + Cost xbarDim = static_cast(crossbarSize.getValue()); + Cost inputTimes = ceilDiv(inputBitwidth, kDacResolutionBits); + Cost dacTimes = ceilDiv(xbarDim, kDacCount); + Cost adcTimes = ceilDiv(xbarDim, kAdcCount); + Cost frontStage = kInputBufferLatencyCycles + kDacLatencyCycles + kXbarReadLatencyNs + kSampleHoldLatencyCycles; + Cost backPipe = std::max(kAdcLatencyCycles, checkedAdd(kShiftAdderLatencyCycles, kOutputBufferLatencyCycles)); + Cost backStage = checkedAdd(checkedAdd(kAdcLatencyCycles, kShiftAdderLatencyCycles), kOutputBufferLatencyCycles); + backStage = checkedAdd(backStage, checkedMultiply(adcTimes - 1, backPipe)); + Cost totalTimes = checkedMultiply(inputTimes, dacTimes); + Cost stagePipe = std::max(frontStage, backStage); + return checkedAdd(checkedAdd(frontStage, backStage), + checkedMultiply(totalTimes - 1, stagePipe)); + } + + static Cost getWvmmCost(Type inputType, Type outputType) { + Cost inputBitwidth = getComputeBitwidth(inputType); + Cost inputBytes = checkedMultiply(static_cast(crossbarSize.getValue()), + ceilDiv(inputBitwidth, static_cast(8))); + inputBytes = checkedMultiply(inputBytes, static_cast(8)); + Cost outputBytes = getByteSize(outputType, getComputeBitwidth(outputType)); + return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getMatrixComputeLatency(inputBitwidth)); + } +}; + +std::optional getStaticTripCount(scf::ForOp loop); +[[maybe_unused]] Cost getOperationCost(Operation& op); + +[[maybe_unused]] Cost getRegionCost(Region& body) { + Cost cost = 0; + for (Block& block : body) + for (Operation& op : block) + cost = checkedAdd(cost, getOperationCost(op)); + return cost; +} + +[[maybe_unused]] Cost getOperationCost(Operation& op) { + if (auto loop = dyn_cast(&op)) { + std::optional tripCount = getStaticTripCount(loop); + if (!tripCount) + return PimsimSchedulerCostModel::kFallbackOperationCost; + return checkedMultiply(getRegionCost(loop.getRegion()), static_cast(*tripCount)); + } + + if (isa(&op)) + return 0; + + if (auto wvmm = dyn_cast(&op)) + return PimsimSchedulerCostModel::getWvmmCost(wvmm.getInput().getType(), wvmm.getOutput().getType()); + if (auto vvdmul = dyn_cast(&op)) + return PimsimSchedulerCostModel::getBinaryVectorCost( + vvdmul.getLhs().getType(), vvdmul.getRhs().getType(), vvdmul.getOutput().getType(), /*scalarOutput=*/true); + if (auto vadd = dyn_cast(&op)) + return PimsimSchedulerCostModel::getBinaryVectorCost(vadd.getLhs().getType(), vadd.getRhs().getType(), + vadd.getOutput().getType()); + if (auto vsub = dyn_cast(&op)) + return PimsimSchedulerCostModel::getBinaryVectorCost(vsub.getLhs().getType(), vsub.getRhs().getType(), + vsub.getOutput().getType()); + if (auto vmul = dyn_cast(&op)) + return PimsimSchedulerCostModel::getBinaryVectorCost(vmul.getLhs().getType(), vmul.getRhs().getType(), + vmul.getOutput().getType()); + if (auto vmax = dyn_cast(&op)) + return PimsimSchedulerCostModel::getBinaryVectorCost(vmax.getLhs().getType(), vmax.getRhs().getType(), + vmax.getOutput().getType()); + if (auto vavg = dyn_cast(&op)) + return PimsimSchedulerCostModel::getUnaryVectorCost(vavg.getInput().getType(), vavg.getOutput().getType(), + /*scalarOutput=*/true); + if (auto relu = dyn_cast(&op)) + return PimsimSchedulerCostModel::getUnaryVectorCost(relu.getInput().getType(), relu.getOutput().getType()); + if (auto sigm = dyn_cast(&op)) + return PimsimSchedulerCostModel::getUnaryVectorCost(sigm.getInput().getType(), sigm.getOutput().getType()); + if (auto softmax = dyn_cast(&op)) { + Cost unary = PimsimSchedulerCostModel::getUnaryVectorCost(softmax.getInput().getType(), softmax.getOutput().getType()); + return checkedMultiply(unary, static_cast(4)); + } + if (auto extract = dyn_cast(&op)) + return PimsimSchedulerCostModel::getTensorMoveCost(extract.getResult().getType()); + if (auto insert = dyn_cast(&op)) + return PimsimSchedulerCostModel::getTensorMoveCost(insert.getSource().getType()); + + Cost nestedCost = 0; + for (Region& region : op.getRegions()) + nestedCost = checkedAdd(nestedCost, getRegionCost(region)); + return checkedAdd(PimsimSchedulerCostModel::kFallbackOperationCost, nestedCost); } std::optional getStaticTripCount(scf::ForOp loop) { @@ -54,6 +270,11 @@ std::optional getStaticTripCount(scf::ForOp loop) { return (distance + stride - 1) / stride; } +Cost getComputeBodyCost(Region& body) { + constexpr Cost kOperationCost = 100; + return checkedMultiply(static_cast(countComputeBodyOperationInstances(body)), kOperationCost); +} + uint64_t countOperationInstances(Operation& op) { if (auto loop = dyn_cast(&op)) { std::optional tripCount = getStaticTripCount(loop); @@ -149,7 +370,8 @@ std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, V auto resultType = dyn_cast(extract.getResult().getType()); if (!resultType || !resultType.hasStaticShape()) return std::nullopt; - projectedCost = checkedAdd(projectedCost, static_cast(getSizeInBytes(resultType))); + projectedCost = checkedAdd( + projectedCost, PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast(getSizeInBytes(resultType)))); } if (projectedCost == 0) @@ -162,7 +384,7 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) if (auto batch = dyn_cast(consumerInstance.op)) if (std::optional projectedCost = getBatchProjectedInputTransferCost(batch, input)) return *projectedCost; - return static_cast(getSizeInBytes(inputType)); + return PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast(getSizeInBytes(inputType))); } uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) { @@ -451,7 +673,7 @@ std::vector aggregateEdges(llvm::ArrayRef ed continue; auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost); if (!inserted.second) - inserted.first->second = std::max(inserted.first->second, edge.transferCost); + inserted.first->second = checkedAdd(inserted.first->second, edge.transferCost); } std::vector aggregatedEdges;