teh only weight (WIP)
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-26 18:42:14 +02:00
parent addfc8a86e
commit d609e84054
17 changed files with 1031 additions and 630 deletions
+20
View File
@@ -40,6 +40,21 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
}
@@ -49,6 +64,11 @@ Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, Operation
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
+8 -4
View File
@@ -1,10 +1,7 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
@@ -17,10 +14,17 @@ mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Type type,
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::RewriterBase& rewriter);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -12,13 +13,12 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -118,6 +118,7 @@ void ONNXToSpatialPass::runOnOperation() {
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
@@ -156,6 +157,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
@@ -189,6 +191,7 @@ void ONNXToSpatialPass::runOnOperation() {
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
@@ -203,6 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
@@ -1,6 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
@@ -13,17 +11,28 @@ using namespace mlir;
namespace onnx_mlir {
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
for (auto extractSlice : func.getOps<tensor::ExtractSliceOp>()) {
auto source = getCompileTimeSource(extractSlice.getOperation());
if (source && hasWeightAlways(source->source) && source->chainLength > 1) {
namespace {
diagnostics.report(extractSlice.getOperation(),
[](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
return;
for (Value result : op->getResults()) {
if (hasOnlySpatialMvmVmmWeightUses(result))
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
});
return;
}
}
});
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
@@ -38,9 +47,7 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightsDirectlyExtracted(funcOp, diagnostics);
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
return success(!diagnostics.hasFailure());
File diff suppressed because it is too large Load Diff
@@ -67,7 +67,7 @@ static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
}
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
std::fstream file = openReportFile("static_memory_coalescing_report");
std::fstream file = openReportFile("memory_coalescing_report");
if (!file.is_open())
return;
+2
View File
@@ -49,6 +49,7 @@ def SpatCompute : SpatOp<"compute",
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
@@ -84,6 +85,7 @@ def SpatComputeBatch : SpatOp<"compute_batch",
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
}];
+17 -3
View File
@@ -1,5 +1,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include <string>
@@ -35,6 +36,17 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
}
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
CrossbarWeightSet collectCrossbarWeights(Region& body) {
CrossbarWeightSet weights;
body.walk([&](SpatVMMOp vmmOp) {
Value weight = vmmOp.getWeight();
weights.insert(weight);
});
return weights;
}
} // namespace
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
@@ -45,7 +57,6 @@ std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
llvm::dbgs() << "Disse netanyao\n";
auto index = std::distance(getWeights().begin(), existing);
return {
{*existing, *getWeightArgument(index)}
@@ -75,6 +86,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigne
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, SpatCompute>>
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
@@ -127,7 +140,6 @@ std::optional<std::tuple<Value, BlockArgument>>
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
auto index = std::distance(getWeights().begin(), existing);
llvm::dbgs() << "Bum bum bum bum\n";
return {
{*existing, *getWeightArgument(index)}
};
@@ -156,6 +168,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
}
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
if (idx > getNumResults())
+3
View File
@@ -10,6 +10,9 @@
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
#include <map>
#include <optional>
#include <string>
+15
View File
@@ -10,6 +10,7 @@
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"
using namespace mlir;
@@ -239,6 +240,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
@@ -264,6 +266,7 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
@@ -273,9 +276,14 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
@@ -357,6 +365,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
if (getNumResults() != 0) {
printer << " shared_outs";
@@ -395,6 +404,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
int32_t crossbarWeightCount = 0;
SmallVector<int32_t> coreIds;
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
@@ -413,9 +423,14 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
(void) crossbarWeightCount;
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
+5 -10
View File
@@ -50,10 +50,9 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
template <typename ComputeOpTy>
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
for (Value weight : computeOp.getWeights()) {
for (Value weight : computeOp.getWeights())
if (!isCompileTimeComputable(weight))
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
}
return success();
}
@@ -131,11 +130,9 @@ verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgume
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
for (Value offset : offsets)
if (!isSupportedLaneOffsetExpr(offset, laneArg))
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
@@ -155,11 +152,9 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
for (Value offset : offsets)
if (!isSupportedLaneOffsetExpr(offset, laneArg))
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
@@ -1,8 +1,6 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
@@ -34,6 +32,7 @@
#include <vector>
#include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeGraph.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "Scheduling/MergeSchedulingAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
@@ -282,11 +281,8 @@ void emitMotifProfile(func::FuncOp funcOp) {
for (auto [index, compute] : llvm::enumerate(computes)) {
ComputeMotifInfo& info = computeInfos[index];
for (Operation& op : compute.getBody().front()) {
info.instructionCount++;
if (isa<spatial::SpatVMMOp>(&op))
info.weightedVmmCount++;
}
info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
if (info.weightedVmmCount > 0) {
weightedVmmNodeCount++;
weightedVmmOpCount += info.weightedVmmCount;
@@ -480,7 +476,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
struct ReportRow {
uint64_t id = 0;
uint64_t logicalComputeCount = 0;
uint64_t weightCount = 0;
uint64_t crossbarCount = 0;
uint64_t instructionCount = 0;
bool isRebatched = false;
SmallVector<int32_t> coreIds;
@@ -490,38 +486,40 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalWeightCount = 0;
uint64_t totalCrossbarCount = 0;
uint64_t nextBatchId = 0;
std::vector<ReportRow> collectedData;
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
};
for (Operation& op : funcOp.getBody().front()) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
uint64_t numInst = 0;
for (auto& _ : spatCompute.getRegion().front())
++numInst;
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds});
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalWeightCount += spatCompute.getWeights().size();
totalCrossbarCount += perInstanceCrossbarCount;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
uint64_t numInst = 0;
for (auto& _ : batch.getRegion().front())
++numInst;
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
totalComputeOps += 1;
totalLogicalComputes += logicalCount;
totalBatchComputeOps += 1;
totalInstructionCount += numInst * logicalCount;
totalWeightCount += batch.getWeights().size();
totalCrossbarCount += perInstanceCrossbarCount * logicalCount;
}
}
@@ -531,7 +529,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
{"Number of logical computes", std::to_string(totalLogicalComputes) },
{"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
{"Number of instructions", std::to_string(totalInstructionCount)},
{"Number of used crossbars", std::to_string(totalWeightCount) }
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
@@ -545,7 +543,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) {
ReportRow next = collectedData[nI];
if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount
if (current.isRebatched == next.isRebatched && current.crossbarCount == next.crossbarCount
&& current.instructionCount == next.instructionCount
&& current.logicalComputeCount == next.logicalComputeCount)
lastIndex = nI;
@@ -578,20 +576,20 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
os << ":\n";
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
uint64_t perCoreInstructionCount = current.instructionCount;
uint64_t perCoreWeightCount =
current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount;
uint64_t perCoreCrossbarCount =
current.logicalComputeCount == 0 ? 0 : current.crossbarCount / current.logicalComputeCount;
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
llvm::SmallVector<ReportField, 3> perCoreFields = {
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
{"Number of instructions", std::to_string(perCoreInstructionCount) },
{"Number of used crossbars", std::to_string(perCoreWeightCount) }
{"Number of used crossbars", std::to_string(perCoreCrossbarCount) }
};
if (current.isRebatched) {
llvm::SmallVector<ReportField, 3> totalEntryFields = {
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
{"Number of instructions", std::to_string(totalEntryInstructionCount) },
{"Number of used crossbars", std::to_string(current.weightCount) }
{"Number of used crossbars", std::to_string(current.crossbarCount) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
}
@@ -655,7 +653,7 @@ public:
}
emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "dcp_merge_report", analysisResult->cpuToLastComputeMap.size());
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
}
}
};
@@ -1,4 +1,8 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Unit.h"
@@ -9,12 +13,14 @@
#include <algorithm>
#include <iterator>
#include <limits>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "ComputeGraph.hpp"
#include "ComputeInstanceUtils.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
@@ -22,15 +28,42 @@ namespace spatial {
using namespace mlir;
uint64_t countComputeBodyInstructions(Region& body);
uint64_t countComputeBodyOperationInstances(Region& body);
namespace {
Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : body)
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
Cost getComputeBodyCost(Region& body) {
constexpr Cost kOperationCost = 100;
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
}
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
auto lb = getConstantIntValue(loop.getLowerBound());
auto ub = getConstantIntValue(loop.getUpperBound());
auto step = getConstantIntValue(loop.getStep());
if (!lb || !ub || !step || *step <= 0)
return std::nullopt;
if (*ub <= *lb)
return 0;
uint64_t distance = static_cast<uint64_t>(*ub - *lb);
uint64_t stride = static_cast<uint64_t>(*step);
return (distance + stride - 1) / stride;
}
uint64_t countOperationInstances(Operation& op) {
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
if (!tripCount)
return 1;
return checkedMultiply(countComputeBodyOperationInstances(loop.getRegion()), *tripCount);
}
uint64_t instances = 1;
for (Region& region : op.getRegions())
instances = checkedAdd(instances, countComputeBodyOperationInstances(region));
return instances;
}
bool isUsedAsWeightOnly(Operation* producerOp) {
@@ -61,7 +94,7 @@ bool isLaneOffset(OpFoldResult offset, Value laneArg) {
return offsetValue == laneArg;
}
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
auto inputIt = llvm::find(batch.getInputs(), input);
if (inputIt == batch.getInputs().end())
return std::nullopt;
@@ -72,7 +105,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
if (!inputArg || !laneArg)
return std::nullopt;
Weight projectedCost = 0;
Cost projectedCost = 0;
for (Operation* user : inputArg->getUsers()) {
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extract || extract.getSource() != *inputArg)
@@ -83,7 +116,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
if (!resultType || !resultType.hasStaticShape())
return std::nullopt;
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
}
if (projectedCost == 0)
@@ -91,28 +124,286 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
return projectedCost;
}
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
auto inputType = cast<ShapedType>(input.getType());
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
return *projectedCost;
return static_cast<Weight>(getSizeInBytes(inputType));
return static_cast<Cost>(getSizeInBytes(inputType));
}
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
CrossbarWeight weight;
weight.opaqueValue = value;
weight.opaqueLane = lane.value_or(std::numeric_limits<uint32_t>::max());
return weight;
}
static FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr<int64_t> {
if (divisor <= 0)
return failure();
if (value >= 0)
return value / divisor;
return -((-value + divisor - 1) / divisor);
};
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs);
case AffineExprKind::CeilDiv:
if (*rhs <= 0)
return failure();
return (*lhs + *rhs - 1) / *rhs;
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDiv(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
static FailureOr<int64_t>
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
const DenseMap<Value, int64_t>& bindings,
std::optional<uint32_t> lane,
Value laneArg) {
if (auto attr = llvm::dyn_cast<Attribute>(value)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr)
return failure();
return intAttr.getInt();
}
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
}
static FailureOr<int64_t> evaluateIndexLike(Value value,
const DenseMap<Value, int64_t>& bindings,
std::optional<uint32_t> lane,
Value laneArg) {
if (lane && value == laneArg)
return *lane;
if (auto it = bindings.find(value); it != bindings.end())
return it->second;
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
return intAttr.getInt();
if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
auto constant = extract.getTensor().getDefiningOp<arith::ConstantOp>();
auto elements = constant ? dyn_cast<ElementsAttr>(constant.getValue()) : nullptr;
auto shapedType = elements ? dyn_cast<ShapedType>(elements.getType()) : nullptr;
if (!elements || !shapedType || shapedType.getRank() != 1 || extract.getIndices().size() != 1)
return failure();
FailureOr<int64_t> index = evaluateIndexLike(extract.getIndices().front(), bindings, lane, laneArg);
if (failed(index) || *index < 0 || *index >= static_cast<int64_t>(elements.getNumElements()))
return failure();
if (auto denseInts = dyn_cast<DenseIntElementsAttr>(elements))
return (*(denseInts.value_begin<APInt>() + *index)).getSExtValue();
return failure();
}
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
if (!affineApply)
return failure();
AffineMap map = affineApply.getAffineMap();
if (map.getNumResults() != 1)
return failure();
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = evaluateIndexLike(operand, bindings, lane, laneArg);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
}
static FailureOr<SmallVector<int64_t, 4>>
evaluateIndexList(ArrayRef<OpFoldResult> values,
const DenseMap<Value, int64_t>& bindings,
std::optional<uint32_t> lane,
Value laneArg) {
SmallVector<int64_t, 4> result;
result.reserve(values.size());
for (OpFoldResult value : values) {
FailureOr<int64_t> folded = evaluateIndexLike(value, bindings, lane, laneArg);
if (failed(folded))
return failure();
result.push_back(*folded);
}
return result;
}
static Value resolveCrossbarWeightRoot(Operation* owner, Value root) {
if (auto arg = dyn_cast<BlockArgument>(root)) {
if (auto compute = dyn_cast<SpatCompute>(owner)) {
for (auto [index, operand] : llvm::enumerate(compute.getWeights()))
if (compute.getWeightArgument(index) == arg)
return operand;
}
if (auto batch = dyn_cast<SpatComputeBatch>(owner)) {
for (auto [index, operand] : llvm::enumerate(batch.getWeights()))
if (batch.getWeightArgument(index) == arg)
return operand;
}
}
return root;
}
static CrossbarWeight completeCrossbarWeight(Value root,
SmallVector<int64_t, 4> offsets,
SmallVector<int64_t, 4> sizes,
SmallVector<int64_t, 4> strides) {
CrossbarWeight weight;
weight.root = root;
if (auto constant = root.getDefiningOp<arith::ConstantOp>())
weight.rootAttr = static_cast<Attribute>(constant.getValue());
weight.offsets = std::move(offsets);
weight.sizes = std::move(sizes);
weight.strides = std::move(strides);
return weight;
}
static FailureOr<CrossbarWeight>
getStaticCrossbarWeight(Operation* owner,
Value value,
const DenseMap<Value, int64_t>& bindings,
std::optional<uint32_t> lane,
Value laneArg) {
if (auto extract = value.getDefiningOp<tensor::ExtractSliceOp>()) {
FailureOr<CrossbarWeight> sourceWeight =
getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg);
auto offsets = evaluateIndexList(extract.getMixedOffsets(), bindings, lane, laneArg);
auto sizes = evaluateIndexList(extract.getMixedSizes(), bindings, lane, laneArg);
auto strides = evaluateIndexList(extract.getMixedStrides(), bindings, lane, laneArg);
if (failed(sourceWeight) || failed(offsets) || failed(sizes) || failed(strides))
return failure();
if (sourceWeight->offsets.size() != offsets->size() || sourceWeight->sizes.size() != sizes->size()
|| sourceWeight->strides.size() != strides->size()) {
return failure();
}
for (auto [index, offset] : llvm::enumerate(*offsets)) {
sourceWeight->offsets[index] += offset * sourceWeight->strides[index];
sourceWeight->sizes[index] = (*sizes)[index];
sourceWeight->strides[index] *= (*strides)[index];
}
return *sourceWeight;
}
Value root = resolveCrossbarWeightRoot(owner, value);
auto type = dyn_cast<ShapedType>(root.getType());
if (!type || !type.hasStaticShape())
return failure();
SmallVector<int64_t, 4> offsets(type.getRank(), 0);
SmallVector<int64_t, 4> sizes(type.getShape().begin(), type.getShape().end());
SmallVector<int64_t, 4> strides(type.getRank(), 1);
return completeCrossbarWeight(root, std::move(offsets), std::move(sizes), std::move(strides));
}
static void addCrossbarWeight(CrossbarUsage& usage, CrossbarWeight weight) {
if (!containsCrossbarWeight(usage, weight))
usage.push_back(std::move(weight));
}
static void collectCrossbarWeightsFromOp(Operation* op,
Operation* owner,
DenseMap<Value, int64_t>& bindings,
CrossbarUsage& usage,
Value laneArg,
std::optional<uint32_t> lane) {
if (auto loop = dyn_cast<scf::ForOp>(op)) {
auto lb = getConstantIntValue(loop.getLowerBound());
auto ub = getConstantIntValue(loop.getUpperBound());
auto step = getConstantIntValue(loop.getStep());
if (!lb || !ub || !step || *step <= 0)
return;
for (int64_t iv = *lb; iv < *ub; iv += *step) {
bindings[loop.getInductionVar()] = iv;
for (Operation& nested : loop.getBody()->without_terminator())
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
}
bindings.erase(loop.getInductionVar());
return;
}
if (auto vmm = dyn_cast<SpatVMMOp>(op)) {
FailureOr<CrossbarWeight> weight = getStaticCrossbarWeight(owner, vmm.getWeight(), bindings, lane, laneArg);
if (failed(weight)) {
addCrossbarWeight(usage, getOpaqueCrossbarWeight(vmm.getWeight(), lane));
return;
}
addCrossbarWeight(usage, *weight);
return;
}
for (Region& region : op->getRegions())
for (Block& block : region)
for (Operation& nested : block.without_terminator())
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
llvm::DenseMap<std::pair<size_t, size_t>, Cost> edgeCosts;
for (const ComputeGraphEdge& edge : edges) {
if (edge.source == edge.target)
continue;
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (const auto& [key, weight] : edgeWeights)
aggregatedEdges.push_back({key.first, key.second, weight});
aggregatedEdges.reserve(edgeCosts.size());
for (const auto& [key, cost] : edgeCosts)
aggregatedEdges.push_back({key.first, key.second, cost});
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
if (lhs.source != rhs.source)
return lhs.source < rhs.source;
@@ -123,30 +414,75 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
} // namespace
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
uint64_t countComputeBodyInstructions(Region& body) {
uint64_t numOperations = 0;
body.walk([&](Operation* op) { numOperations = checkedAdd(numOperations, static_cast<uint64_t>(1)); });
return numOperations;
}
uint64_t countComputeBodyOperationInstances(Region& body) {
uint64_t instances = 0;
for (Block& block : body)
for (Operation& op : block)
instances = checkedAdd(instances, countOperationInstances(op));
return instances;
}
CrossbarUsage collectDistinctCrossbarWeights(Operation* owner, std::optional<uint32_t> lane) {
CrossbarUsage usage;
DenseMap<Value, int64_t> bindings;
Value laneArg;
if (auto batch = dyn_cast<SpatComputeBatch>(owner))
if (auto maybeLaneArg = batch.getLaneArgument())
laneArg = *maybeLaneArg;
for (Region& region : owner->getRegions())
for (Block& block : region)
for (Operation& op : block.without_terminator())
collectCrossbarWeightsFromOp(&op, owner, bindings, usage, laneArg, lane);
return usage;
}
Cost getComputeInstanceCost(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getComputeBodyWeight(spatCompute.getBody());
return getComputeBodyCost(spatCompute.getBody());
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
return checkedMultiply(getComputeBodyCost(batch.getBody()), static_cast<Cost>(instance.laneCount));
}
CrossbarUsage getSpatComputeCrossbarUsage(SpatCompute spatComute){
CrossbarUsage ret;
ret.insert_range(spatComute.getWeights());
return ret;
bool containsCrossbarWeight(ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight) {
return llvm::is_contained(usage, weight);
}
CrossbarUsage getSpatComputeBatchCrossbarUsage(SpatComputeBatch spatComuteBatch){
CrossbarUsage ret;
ret.insert_range(spatComuteBatch.getWeights());
return ret;
unsigned countCrossbarOverlap(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
unsigned overlap = 0;
for (const CrossbarWeight& weight : rhs)
if (containsCrossbarWeight(lhs, weight))
++overlap;
return overlap;
}
size_t getCrossbarUnionSize(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
size_t size = lhs.size();
for (const CrossbarWeight& weight : rhs)
if (!containsCrossbarWeight(lhs, weight))
++size;
return size;
}
void insertCrossbarWeights(CrossbarUsage& usage, ArrayRef<CrossbarWeight> weights) {
for (const CrossbarWeight& weight : weights)
addCrossbarWeight(usage, weight);
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return getSpatComputeBatchCrossbarUsage(batch);
CrossbarUsage usage;
if (isa<SpatCompute>(instance.op))
return collectDistinctCrossbarWeights(instance.op);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
insertCrossbarWeights(usage, collectDistinctCrossbarWeights(instance.op, lane));
return usage;
}
ComputeGraph buildComputeGraph(Operation* entryOp) {
@@ -161,7 +497,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
continue;
}
@@ -173,7 +509,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
}
}
@@ -185,7 +521,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
for (Value input : inputs) {
Weight transferCost = getInputTransferCost(node.instance, input);
Cost transferCost = getInputTransferCost(node.instance, input);
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
@@ -208,7 +544,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
}
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
graph.edges.insert(graph.edges.end(), aggregatedEdges.begin(), aggregatedEdges.end());
graph.successors.assign(graph.nodes.size(), {});
graph.predecessors.assign(graph.nodes.size(), {});
for (const ComputeGraphEdge& edge : graph.edges) {
@@ -233,8 +569,8 @@ bool verifyAcyclic(const ComputeGraph& graph) {
size_t node = readyNodes.front();
readyNodes.pop();
++visited;
for (const auto& [child, weight] : graph.successors[node]) {
(void) weight;
for (const auto& [child, cost] : graph.successors[node]) {
(void) cost;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyNodes.push(child);
@@ -1,52 +1,72 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include "Utils.hpp"
#include "ComputeInstance.hpp"
#include "ComputeInstanceUtils.hpp"
#include "Utils.hpp"
using CrossbarUsage = llvm::SmallPtrSet<mlir::Value, 6>;
struct CrossbarWeight {
mlir::Value root;
mlir::Attribute rootAttr;
llvm::SmallVector<int64_t, 4> offsets;
llvm::SmallVector<int64_t, 4> sizes;
llvm::SmallVector<int64_t, 4> strides;
mlir::Value opaqueValue;
uint32_t opaqueLane = 0;
bool operator==(const CrossbarWeight& other) const {
bool sameRoot = rootAttr && other.rootAttr ? rootAttr == other.rootAttr : root == other.root;
return sameRoot && offsets == other.offsets && sizes == other.sizes && strides == other.strides
&& opaqueValue == other.opaqueValue && opaqueLane == other.opaqueLane;
}
};
using CrossbarUsage = llvm::SmallVector<CrossbarWeight, 6>;
namespace onnx_mlir {
namespace spatial {
struct ComputeGraphNode {
ComputeInstance instance;
Weight weight = 0;
llvm::SmallPtrSet<mlir::Value,6> crossbarUsage;
Cost cost = 0;
CrossbarUsage crossbarUsage;
size_t originalOrder = 0;
};
struct ComputeGraphEdge {
size_t source = 0;
size_t target = 0;
Weight transferCost = 0;
Cost transferCost = 0;
};
struct ComputeGraph {
llvm::SmallVector<ComputeGraphNode> nodes;
llvm::SmallVector<ComputeGraphEdge> edges;
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
std::vector<ComputeGraphNode> nodes;
std::vector<ComputeGraphEdge> edges;
std::vector<std::vector<std::pair<size_t, Cost>>> successors;
std::vector<std::vector<std::pair<size_t, Cost>>> predecessors;
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
};
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
bool verifyAcyclic(const ComputeGraph& graph);
Weight getComputeInstanceWeight(const ComputeInstance& instance);
uint64_t countComputeBodyInstructions(mlir::Region& body);
uint64_t countComputeBodyOperationInstances(mlir::Region& body);
Cost getComputeInstanceCost(const ComputeInstance& instance);
CrossbarUsage collectDistinctCrossbarWeights(mlir::Operation* owner, std::optional<uint32_t> lane = std::nullopt);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
bool containsCrossbarWeight(llvm::ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight);
unsigned countCrossbarOverlap(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
size_t getCrossbarUnionSize(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
void insertCrossbarWeights(CrossbarUsage& usage, llvm::ArrayRef<CrossbarWeight> weights);
} // namespace spatial
} // namespace onnx_mlir
@@ -48,12 +48,12 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
return lhs.second < rhs.second;
});
unsigned int usedCrossbars = 0;
CrossbarUsage usedCrossbars;
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
if (scheduledTasks[slot].first != slot)
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage.size());
if (usedCrossbars > crossbarCapacity)
insertCrossbarWeights(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
if (usedCrossbars.size() > crossbarCapacity)
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
}
@@ -77,7 +77,7 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
if (targetStart < earliestTargetStart) {
@@ -89,8 +89,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
// MOCK: Replace this with your actual heterogeneous cost lookup.
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
@@ -163,7 +163,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
}
std::vector<char> scheduled(nodeCount, false);
std::vector<CrossbarUsage> processorCrossbars(processorCount, llvm::SmallPtrSet<mlir::Value, 6> {});
std::vector<CrossbarUsage> processorCrossbars(processorCount);
std::vector<ScheduledTask> schedules(nodeCount);
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
@@ -178,25 +178,17 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time bestEst = 0;
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
unsigned int bestOverlapWeight = 0;
unsigned int bestOverlapCount = 0;
bool crossbarRejected = false;
auto crossbarsAreContainedInProcessor = [&processorCrossbars](mlir::Value nodeCrossbar, size_t processor) {
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
};
for (size_t processor = 0; processor < processorCount; ++processor) {
auto crossbarsAreContained = std::bind(crossbarsAreContainedInProcessor, std::placeholders::_1, processor);
if (graph.nodes[task].crossbarUsage.size() != 0
&& !llvm::all_of(graph.nodes[task].crossbarUsage, crossbarsAreContained)
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
unsigned int overlapCount = countCrossbarOverlap(processorCrossbars[processor], graph.nodes[task].crossbarUsage);
if (!graph.nodes[task].crossbarUsage.empty()
&& getCrossbarUnionSize(processorCrossbars[processor], graph.nodes[task].crossbarUsage)
> options.crossbarCapacity) {
crossbarRejected = true;
continue;
}
unsigned int overlapWeight =
llvm::count_if(graph.nodes[task].crossbarUsage, crossbarsAreContained);
crossbarRejected = true;
continue;
}
Time dataReady = 0;
for (const auto& [pred, comm] : graph.predecessors[task]) {
@@ -206,7 +198,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
}
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
Time compWeight = getComputeCost(task, processor);
Time computeCost = getComputeCost(task, processor);
Time est = dataReady;
Time currentEnd = 0;
bool foundGap = false;
@@ -215,7 +207,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
const ScheduledTask& schedTask = schedules[schedTaskIndex];
Time gapStart = std::max(currentEnd, dataReady);
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
if (addOrMax(gapStart, computeCost) <= schedTask.startTime) {
est = gapStart;
foundGap = true;
break;
@@ -226,7 +218,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
if (!foundGap)
est = std::max(currentEnd, dataReady);
Time eft = addOrMax(est, compWeight);
Time eft = addOrMax(est, computeCost);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
@@ -235,14 +227,14 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapWeight = overlapWeight;
bestOverlapCount = overlapCount;
}
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapWeight < bestOverlapWeight) {
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapWeight = overlapWeight;
bestOverlapCount = overlapCount;
}
}
@@ -265,7 +257,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
schedules[task] = {bestProcessor, bestEst, bestEft};
scheduled[task] = true;
++scheduledCount;
processorCrossbars[bestProcessor].insert_range(graph.nodes[task].crossbarUsage);
insertCrossbarWeights(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
// 3. CRITICAL FIX: Topological Append
// Because the readyQueue pops in strict topological order, simply pushing to the
@@ -319,37 +311,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
}
}
}
/*{
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
std::vector<bool> visited(processorCount, false);
size_t uniqueClassCount = 0;
for (size_t i = 0; i < processorCount; ++i) {
if (visited[i])
continue;
// We found a new unique schedule (equivalence class)
++uniqueClassCount;
visited[i] = true;
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
// Find and mark all identical companions
auto it = equivalentClass.find(i);
if (it != equivalentClass.end()) {
for (size_t eqCpu : it->second) {
if (!visited[eqCpu]) {
llvm::dbgs() << ", " << eqCpu;
visited[eqCpu] = true;
}
}
}
llvm::dbgs() << " }\n";
}
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
llvm::dbgs() << "--------------------------------------\n";
}*/
// 6. Populate Final Result
MergeScheduleResult result;
@@ -16,7 +16,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using CPU = int;
using Weight = unsigned long long;
using Cost = unsigned long long;
using Time = unsigned long long;
@@ -55,4 +55,3 @@ inline T subtractOrZero(T lhs, T rhs) {
}
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }