peft cost model

This commit is contained in:
ilgeco
2026-06-18 10:57:59 +02:00
parent e083c27d80
commit 4ab24eb288
@@ -12,6 +12,7 @@
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cmath>
#include <iterator>
#include <limits>
#include <optional>
@@ -21,6 +22,7 @@
#include "ComputeGraph.hpp"
#include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Support/TypeUtilities.hpp"
@@ -35,9 +37,223 @@ uint64_t countComputeBodyOperationInstances(Region& body);
namespace {
Cost getComputeBodyCost(Region& body) {
constexpr Cost kOperationCost = 100;
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
struct PimsimSchedulerCostModel {
static constexpr Cost kDefaultBitwidth = 8;
static constexpr Cost kCorePeriodNs = 1;
static constexpr Cost kLocalMemoryWidthBytes = 64;
static constexpr Cost kLocalMemoryLatencyCycles = 1;
static constexpr Cost kNetworkBusWidthBytes = 8;
static constexpr Cost kNetworkBaseLatencyNs = 2;
static constexpr Cost kNetworkPerHopLatencyNs = 1;
static constexpr Cost kVectorWidth = 16;
static constexpr Cost kVectorLatencyCycles = 4;
static constexpr Cost kDacResolutionBits = 1;
static constexpr Cost kDacLatencyCycles = 1;
static constexpr Cost kDacCount = 128;
static constexpr Cost kXbarReadLatencyNs = 30;
static constexpr Cost kSampleHoldLatencyCycles = 1;
static constexpr Cost kAdcLatencyCycles = 10;
static constexpr Cost kAdcCount = 2;
static constexpr Cost kShiftAdderLatencyCycles = 1;
static constexpr Cost kOutputBufferLatencyCycles = 1;
static constexpr Cost kInputBufferLatencyCycles = 0;
static constexpr Cost kFallbackOperationCost = 1;
static Cost ceilDiv(Cost numerator, Cost denominator) {
assert(denominator > 0 && "denominator must be positive");
return (numerator + denominator - 1) / denominator;
}
static std::optional<Cost> getStaticElementCount(Type type) {
auto shaped = dyn_cast<ShapedType>(type);
if (!shaped || !shaped.hasStaticShape())
return std::nullopt;
return static_cast<Cost>(shaped.getNumElements());
}
static Cost getBitwidthOrDefault(Type type) {
if (auto shaped = dyn_cast<ShapedType>(type))
type = shaped.getElementType();
if (auto intType = dyn_cast<IntegerType>(type))
return intType.getWidth();
if (auto floatType = dyn_cast<FloatType>(type))
return floatType.getWidth();
if (isa<IndexType>(type))
return 64;
return kDefaultBitwidth;
}
static Cost getComputeBitwidth(Type type) {
return std::min(getBitwidthOrDefault(type), kDefaultBitwidth);
}
static Cost getByteSize(Type type, Cost fallbackBitwidth = kDefaultBitwidth) {
auto elementCount = getStaticElementCount(type);
if (!elementCount)
return kFallbackOperationCost;
Cost bitwidth = fallbackBitwidth;
if (bitwidth <= 0)
bitwidth = getBitwidthOrDefault(type);
if (bitwidth <= 0)
bitwidth = kDefaultBitwidth;
return ceilDiv(checkedMultiply(*elementCount, bitwidth), static_cast<Cost>(8));
}
static Cost getVectorReadWriteCost(Cost readBytes, Cost writeBytes) {
Cost totalBytes = checkedAdd(readBytes, writeBytes);
return checkedMultiply(ceilDiv(totalBytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
}
static Cost getVectorComputeCost(Cost elementCount) {
return checkedMultiply(ceilDiv(elementCount, kVectorWidth),
checkedMultiply(kVectorLatencyCycles, kCorePeriodNs));
}
static Cost getTensorMoveCost(Type type) {
return getVectorReadWriteCost(getByteSize(type), 0);
}
static std::pair<Cost, Cost> estimateMeshShape() {
Cost coreCount = static_cast<Cost>(std::max<long>(1, coresCount.getValue()));
Cost rows = static_cast<Cost>(std::sqrt(static_cast<long double>(coreCount)));
if (rows == 0)
rows = 1;
while (rows > 1 && coreCount % rows != 0)
--rows;
Cost cols = ceilDiv(coreCount, rows);
return {rows, cols};
}
static Cost getAverageInterCoreLatencyNs() {
auto [rows, cols] = estimateMeshShape();
auto averageAxisDistance = [](Cost size) -> Cost {
if (size <= 1)
return 0;
return checkedMultiply(size, size) - 1;
};
Cost avgRow = averageAxisDistance(rows) / (static_cast<Cost>(3) * rows);
Cost avgCol = averageAxisDistance(cols) / (static_cast<Cost>(3) * cols);
return checkedAdd(kNetworkBaseLatencyNs, checkedMultiply(kNetworkPerHopLatencyNs, checkedAdd(avgRow, avgCol)));
}
static Cost getInterCoreTransferCostFromBytes(Cost bytes) {
Cost localRead = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
Cost localWrite = checkedMultiply(ceilDiv(bytes, kLocalMemoryWidthBytes),
checkedMultiply(kLocalMemoryLatencyCycles, kCorePeriodNs));
Cost payloadFlits = ceilDiv(bytes, kNetworkBusWidthBytes);
Cost averageNoCLatency = getAverageInterCoreLatencyNs();
Cost network = checkedMultiply(checkedAdd(static_cast<Cost>(2), payloadFlits), averageNoCLatency);
return checkedAdd(checkedAdd(localRead, localWrite), network);
}
static Cost getUnaryVectorCost(Type inputType, Type outputType, bool scalarOutput = false) {
auto maybeElements = getStaticElementCount(inputType);
if (!maybeElements)
return kFallbackOperationCost;
Cost inputBytes = getByteSize(inputType, getComputeBitwidth(inputType));
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
: getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getVectorComputeCost(*maybeElements));
}
static Cost getBinaryVectorCost(Type lhsType, Type rhsType, Type outputType, bool scalarOutput = false) {
auto maybeElements = getStaticElementCount(lhsType);
if (!maybeElements)
return kFallbackOperationCost;
Cost readBytes = checkedAdd(getByteSize(lhsType, getComputeBitwidth(lhsType)),
getByteSize(rhsType, getComputeBitwidth(rhsType)));
Cost outputBytes = scalarOutput ? ceilDiv(getComputeBitwidth(outputType), static_cast<Cost>(8))
: getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(readBytes, outputBytes), getVectorComputeCost(*maybeElements));
}
static Cost getMatrixComputeLatency(Cost inputBitwidth) {
Cost xbarDim = static_cast<Cost>(crossbarSize.getValue());
Cost inputTimes = ceilDiv(inputBitwidth, kDacResolutionBits);
Cost dacTimes = ceilDiv(xbarDim, kDacCount);
Cost adcTimes = ceilDiv(xbarDim, kAdcCount);
Cost frontStage = kInputBufferLatencyCycles + kDacLatencyCycles + kXbarReadLatencyNs + kSampleHoldLatencyCycles;
Cost backPipe = std::max(kAdcLatencyCycles, checkedAdd(kShiftAdderLatencyCycles, kOutputBufferLatencyCycles));
Cost backStage = checkedAdd(checkedAdd(kAdcLatencyCycles, kShiftAdderLatencyCycles), kOutputBufferLatencyCycles);
backStage = checkedAdd(backStage, checkedMultiply(adcTimes - 1, backPipe));
Cost totalTimes = checkedMultiply(inputTimes, dacTimes);
Cost stagePipe = std::max(frontStage, backStage);
return checkedAdd(checkedAdd(frontStage, backStage),
checkedMultiply(totalTimes - 1, stagePipe));
}
static Cost getWvmmCost(Type inputType, Type outputType) {
Cost inputBitwidth = getComputeBitwidth(inputType);
Cost inputBytes = checkedMultiply(static_cast<Cost>(crossbarSize.getValue()),
ceilDiv(inputBitwidth, static_cast<Cost>(8)));
inputBytes = checkedMultiply(inputBytes, static_cast<Cost>(8));
Cost outputBytes = getByteSize(outputType, getComputeBitwidth(outputType));
return checkedAdd(getVectorReadWriteCost(inputBytes, outputBytes), getMatrixComputeLatency(inputBitwidth));
}
};
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop);
[[maybe_unused]] Cost getOperationCost(Operation& op);
[[maybe_unused]] Cost getRegionCost(Region& body) {
Cost cost = 0;
for (Block& block : body)
for (Operation& op : block)
cost = checkedAdd(cost, getOperationCost(op));
return cost;
}
[[maybe_unused]] Cost getOperationCost(Operation& op) {
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
if (!tripCount)
return PimsimSchedulerCostModel::kFallbackOperationCost;
return checkedMultiply(getRegionCost(loop.getRegion()), static_cast<Cost>(*tripCount));
}
if (isa<SpatYieldOp, SpatInParallelOp, affine::AffineApplyOp, arith::ConstantOp,
tensor::EmptyOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp>(&op))
return 0;
if (auto wvmm = dyn_cast<SpatVMMOp>(&op))
return PimsimSchedulerCostModel::getWvmmCost(wvmm.getInput().getType(), wvmm.getOutput().getType());
if (auto vvdmul = dyn_cast<SpatVVDMulOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(
vvdmul.getLhs().getType(), vvdmul.getRhs().getType(), vvdmul.getOutput().getType(), /*scalarOutput=*/true);
if (auto vadd = dyn_cast<SpatVAddOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vadd.getLhs().getType(), vadd.getRhs().getType(),
vadd.getOutput().getType());
if (auto vsub = dyn_cast<SpatVSubOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vsub.getLhs().getType(), vsub.getRhs().getType(),
vsub.getOutput().getType());
if (auto vmul = dyn_cast<SpatVMulOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vmul.getLhs().getType(), vmul.getRhs().getType(),
vmul.getOutput().getType());
if (auto vmax = dyn_cast<SpatVMaxOp>(&op))
return PimsimSchedulerCostModel::getBinaryVectorCost(vmax.getLhs().getType(), vmax.getRhs().getType(),
vmax.getOutput().getType());
if (auto vavg = dyn_cast<SpatVAvgOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(vavg.getInput().getType(), vavg.getOutput().getType(),
/*scalarOutput=*/true);
if (auto relu = dyn_cast<SpatReluOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(relu.getInput().getType(), relu.getOutput().getType());
if (auto sigm = dyn_cast<SpatSigmoidOp>(&op))
return PimsimSchedulerCostModel::getUnaryVectorCost(sigm.getInput().getType(), sigm.getOutput().getType());
if (auto softmax = dyn_cast<SpatSoftmaxOp>(&op)) {
Cost unary = PimsimSchedulerCostModel::getUnaryVectorCost(softmax.getInput().getType(), softmax.getOutput().getType());
return checkedMultiply(unary, static_cast<Cost>(4));
}
if (auto extract = dyn_cast<tensor::ExtractSliceOp>(&op))
return PimsimSchedulerCostModel::getTensorMoveCost(extract.getResult().getType());
if (auto insert = dyn_cast<tensor::InsertSliceOp>(&op))
return PimsimSchedulerCostModel::getTensorMoveCost(insert.getSource().getType());
Cost nestedCost = 0;
for (Region& region : op.getRegions())
nestedCost = checkedAdd(nestedCost, getRegionCost(region));
return checkedAdd(PimsimSchedulerCostModel::kFallbackOperationCost, nestedCost);
}
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
@@ -54,6 +270,11 @@ std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
return (distance + stride - 1) / stride;
}
Cost getComputeBodyCost(Region& body) {
constexpr Cost kOperationCost = 100;
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
}
uint64_t countOperationInstances(Operation& op) {
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
@@ -149,7 +370,8 @@ std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, V
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return std::nullopt;
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
projectedCost = checkedAdd(
projectedCost, PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(resultType))));
}
if (projectedCost == 0)
@@ -162,7 +384,7 @@ Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input)
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
return *projectedCost;
return static_cast<Cost>(getSizeInBytes(inputType));
return PimsimSchedulerCostModel::getInterCoreTransferCostFromBytes(static_cast<Cost>(getSizeInBytes(inputType)));
}
uint32_t getLaneOverlapCount(const ComputeInstance& lhs, const ComputeInstance& rhs) {
@@ -451,7 +673,7 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
continue;
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
inserted.first->second = checkedAdd(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;