This commit is contained in:
@@ -40,6 +40,21 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(hostBlock);
|
||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||
}
|
||||
@@ -49,6 +64,11 @@ Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, Operation
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
@@ -17,10 +14,17 @@ mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Type type,
|
||||
mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::RewriterBase& rewriter);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
@@ -12,13 +13,12 @@
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -118,6 +118,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
@@ -156,6 +157,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
@@ -189,6 +191,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
@@ -203,6 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "Common/IR/WeightUtils.hpp"
|
||||
@@ -13,17 +11,28 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto extractSlice : func.getOps<tensor::ExtractSliceOp>()) {
|
||||
auto source = getCompileTimeSource(extractSlice.getOperation());
|
||||
if (source && hasWeightAlways(source->source) && source->chainLength > 1) {
|
||||
namespace {
|
||||
|
||||
diagnostics.report(extractSlice.getOperation(),
|
||||
[](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
|
||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
func.walk([&](Operation* op) {
|
||||
if (!hasWeightAlways(op))
|
||||
return;
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(result))
|
||||
continue;
|
||||
|
||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
@@ -38,9 +47,7 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
checkWeightsDirectlyExtracted(funcOp, diagnostics);
|
||||
|
||||
checkWeightUseChains(funcOp, diagnostics);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -67,7 +67,7 @@ static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
|
||||
}
|
||||
|
||||
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
std::fstream file = openReportFile("static_memory_coalescing_report");
|
||||
std::fstream file = openReportFile("memory_coalescing_report");
|
||||
if (!file.is_open())
|
||||
return;
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ def SpatCompute : SpatOp<"compute",
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
@@ -84,6 +85,7 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
@@ -35,6 +36,17 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
|
||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
}
|
||||
|
||||
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
||||
|
||||
CrossbarWeightSet collectCrossbarWeights(Region& body) {
|
||||
CrossbarWeightSet weights;
|
||||
body.walk([&](SpatVMMOp vmmOp) {
|
||||
Value weight = vmmOp.getWeight();
|
||||
weights.insert(weight);
|
||||
});
|
||||
return weights;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
||||
@@ -45,7 +57,6 @@ std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
llvm::dbgs() << "Disse netanyao\n";
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
return {
|
||||
{*existing, *getWeightArgument(index)}
|
||||
@@ -75,6 +86,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigne
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
|
||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
@@ -127,7 +140,6 @@ std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
llvm::dbgs() << "Bum bum bum bum\n";
|
||||
return {
|
||||
{*existing, *getWeightArgument(index)}
|
||||
};
|
||||
@@ -156,6 +168,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
|
||||
@@ -10,6 +10,9 @@
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -239,6 +240,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
@@ -264,6 +266,7 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
@@ -273,9 +276,14 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
@@ -357,6 +365,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
@@ -395,6 +404,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t crossbarWeightCount = 0;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
@@ -413,9 +423,14 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
(void) crossbarWeightCount;
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||
|
||||
@@ -50,10 +50,9 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
|
||||
for (Value weight : computeOp.getWeights()) {
|
||||
for (Value weight : computeOp.getWeights())
|
||||
if (!isCompileTimeComputable(weight))
|
||||
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -131,11 +130,9 @@ verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgume
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
for (Value offset : offsets)
|
||||
if (!isSupportedLaneOffsetExpr(offset, laneArg))
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -155,11 +152,9 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
for (Value offset : offsets)
|
||||
if (!isSupportedLaneOffsetExpr(offset, laneArg))
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -34,6 +32,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "Scheduling/ComputeGraph.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||
@@ -282,11 +281,8 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
||||
|
||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||
ComputeMotifInfo& info = computeInfos[index];
|
||||
for (Operation& op : compute.getBody().front()) {
|
||||
info.instructionCount++;
|
||||
if (isa<spatial::SpatVMMOp>(&op))
|
||||
info.weightedVmmCount++;
|
||||
}
|
||||
info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
|
||||
compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
|
||||
if (info.weightedVmmCount > 0) {
|
||||
weightedVmmNodeCount++;
|
||||
weightedVmmOpCount += info.weightedVmmCount;
|
||||
@@ -480,7 +476,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
struct ReportRow {
|
||||
uint64_t id = 0;
|
||||
uint64_t logicalComputeCount = 0;
|
||||
uint64_t weightCount = 0;
|
||||
uint64_t crossbarCount = 0;
|
||||
uint64_t instructionCount = 0;
|
||||
bool isRebatched = false;
|
||||
SmallVector<int32_t> coreIds;
|
||||
@@ -490,38 +486,40 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
uint64_t totalLogicalComputes = 0;
|
||||
uint64_t totalBatchComputeOps = 0;
|
||||
uint64_t totalInstructionCount = 0;
|
||||
uint64_t totalWeightCount = 0;
|
||||
uint64_t totalCrossbarCount = 0;
|
||||
uint64_t nextBatchId = 0;
|
||||
std::vector<ReportRow> collectedData;
|
||||
|
||||
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||
};
|
||||
|
||||
for (Operation& op : funcOp.getBody().front()) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
uint64_t numInst = 0;
|
||||
for (auto& _ : spatCompute.getRegion().front())
|
||||
++numInst;
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreId = getComputeCoreId(spatCompute))
|
||||
coreIds.push_back(*coreId);
|
||||
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds});
|
||||
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
totalLogicalComputes += 1;
|
||||
totalInstructionCount += numInst;
|
||||
totalWeightCount += spatCompute.getWeights().size();
|
||||
totalCrossbarCount += perInstanceCrossbarCount;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
uint64_t numInst = 0;
|
||||
for (auto& _ : batch.getRegion().front())
|
||||
++numInst;
|
||||
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
||||
collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
|
||||
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
||||
totalComputeOps += 1;
|
||||
totalLogicalComputes += logicalCount;
|
||||
totalBatchComputeOps += 1;
|
||||
totalInstructionCount += numInst * logicalCount;
|
||||
totalWeightCount += batch.getWeights().size();
|
||||
totalCrossbarCount += perInstanceCrossbarCount * logicalCount;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -531,7 +529,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
{"Number of logical computes", std::to_string(totalLogicalComputes) },
|
||||
{"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
|
||||
{"Number of instructions", std::to_string(totalInstructionCount)},
|
||||
{"Number of used crossbars", std::to_string(totalWeightCount) }
|
||||
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!collectedData.empty())
|
||||
@@ -545,7 +543,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
|
||||
for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) {
|
||||
ReportRow next = collectedData[nI];
|
||||
if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount
|
||||
if (current.isRebatched == next.isRebatched && current.crossbarCount == next.crossbarCount
|
||||
&& current.instructionCount == next.instructionCount
|
||||
&& current.logicalComputeCount == next.logicalComputeCount)
|
||||
lastIndex = nI;
|
||||
@@ -578,20 +576,20 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
os << ":\n";
|
||||
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
|
||||
uint64_t perCoreInstructionCount = current.instructionCount;
|
||||
uint64_t perCoreWeightCount =
|
||||
current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount;
|
||||
uint64_t perCoreCrossbarCount =
|
||||
current.logicalComputeCount == 0 ? 0 : current.crossbarCount / current.logicalComputeCount;
|
||||
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
|
||||
|
||||
llvm::SmallVector<ReportField, 3> perCoreFields = {
|
||||
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
|
||||
{"Number of instructions", std::to_string(perCoreInstructionCount) },
|
||||
{"Number of used crossbars", std::to_string(perCoreWeightCount) }
|
||||
{"Number of used crossbars", std::to_string(perCoreCrossbarCount) }
|
||||
};
|
||||
if (current.isRebatched) {
|
||||
llvm::SmallVector<ReportField, 3> totalEntryFields = {
|
||||
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
|
||||
{"Number of instructions", std::to_string(totalEntryInstructionCount) },
|
||||
{"Number of used crossbars", std::to_string(current.weightCount) }
|
||||
{"Number of used crossbars", std::to_string(current.crossbarCount) }
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
|
||||
}
|
||||
@@ -655,7 +653,7 @@ public:
|
||||
}
|
||||
emitMergeIrCounts("final-post-merge", func);
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||
generateReport(func, "dcp_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Unit.h"
|
||||
@@ -9,12 +13,14 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -22,15 +28,42 @@ namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
uint64_t countComputeBodyInstructions(Region& body);
|
||||
uint64_t countComputeBodyOperationInstances(Region& body);
|
||||
|
||||
namespace {
|
||||
|
||||
Weight getComputeBodyWeight(Region& body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto& block : body)
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
Cost getComputeBodyCost(Region& body) {
|
||||
constexpr Cost kOperationCost = 100;
|
||||
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
|
||||
}
|
||||
|
||||
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
|
||||
auto lb = getConstantIntValue(loop.getLowerBound());
|
||||
auto ub = getConstantIntValue(loop.getUpperBound());
|
||||
auto step = getConstantIntValue(loop.getStep());
|
||||
if (!lb || !ub || !step || *step <= 0)
|
||||
return std::nullopt;
|
||||
if (*ub <= *lb)
|
||||
return 0;
|
||||
|
||||
uint64_t distance = static_cast<uint64_t>(*ub - *lb);
|
||||
uint64_t stride = static_cast<uint64_t>(*step);
|
||||
return (distance + stride - 1) / stride;
|
||||
}
|
||||
|
||||
uint64_t countOperationInstances(Operation& op) {
|
||||
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
|
||||
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
|
||||
if (!tripCount)
|
||||
return 1;
|
||||
return checkedMultiply(countComputeBodyOperationInstances(loop.getRegion()), *tripCount);
|
||||
}
|
||||
|
||||
uint64_t instances = 1;
|
||||
for (Region& region : op.getRegions())
|
||||
instances = checkedAdd(instances, countComputeBodyOperationInstances(region));
|
||||
return instances;
|
||||
}
|
||||
|
||||
bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||
@@ -61,7 +94,7 @@ bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
||||
return offsetValue == laneArg;
|
||||
}
|
||||
|
||||
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||
auto inputIt = llvm::find(batch.getInputs(), input);
|
||||
if (inputIt == batch.getInputs().end())
|
||||
return std::nullopt;
|
||||
@@ -72,7 +105,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
||||
if (!inputArg || !laneArg)
|
||||
return std::nullopt;
|
||||
|
||||
Weight projectedCost = 0;
|
||||
Cost projectedCost = 0;
|
||||
for (Operation* user : inputArg->getUsers()) {
|
||||
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||
if (!extract || extract.getSource() != *inputArg)
|
||||
@@ -83,7 +116,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
||||
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return std::nullopt;
|
||||
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
|
||||
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
|
||||
}
|
||||
|
||||
if (projectedCost == 0)
|
||||
@@ -91,28 +124,286 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
||||
return projectedCost;
|
||||
}
|
||||
|
||||
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
||||
Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
||||
auto inputType = cast<ShapedType>(input.getType());
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
||||
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||
return *projectedCost;
|
||||
return static_cast<Weight>(getSizeInBytes(inputType));
|
||||
return static_cast<Cost>(getSizeInBytes(inputType));
|
||||
}
|
||||
|
||||
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
|
||||
CrossbarWeight weight;
|
||||
weight.opaqueValue = value;
|
||||
weight.opaqueLane = lane.value_or(std::numeric_limits<uint32_t>::max());
|
||||
return weight;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||
return constant.getValue();
|
||||
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
|
||||
unsigned position = dim.getPosition();
|
||||
if (position >= dims.size())
|
||||
return failure();
|
||||
return dims[position];
|
||||
}
|
||||
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
unsigned position = symbol.getPosition();
|
||||
if (position >= symbols.size())
|
||||
return failure();
|
||||
return symbols[position];
|
||||
}
|
||||
|
||||
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
|
||||
if (!binary)
|
||||
return failure();
|
||||
|
||||
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
|
||||
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
|
||||
auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr<int64_t> {
|
||||
if (divisor <= 0)
|
||||
return failure();
|
||||
if (value >= 0)
|
||||
return value / divisor;
|
||||
return -((-value + divisor - 1) / divisor);
|
||||
};
|
||||
|
||||
switch (binary.getKind()) {
|
||||
case AffineExprKind::Add: return *lhs + *rhs;
|
||||
case AffineExprKind::Mul: return *lhs * *rhs;
|
||||
case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs);
|
||||
case AffineExprKind::CeilDiv:
|
||||
if (*rhs <= 0)
|
||||
return failure();
|
||||
return (*lhs + *rhs - 1) / *rhs;
|
||||
case AffineExprKind::Mod: {
|
||||
FailureOr<int64_t> div = floorDiv(*lhs, *rhs);
|
||||
if (failed(div))
|
||||
return failure();
|
||||
return *lhs - *div * *rhs;
|
||||
}
|
||||
default: return failure();
|
||||
}
|
||||
}
|
||||
|
||||
static FailureOr<int64_t>
|
||||
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
|
||||
|
||||
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
|
||||
const DenseMap<Value, int64_t>& bindings,
|
||||
std::optional<uint32_t> lane,
|
||||
Value laneArg) {
|
||||
if (auto attr = llvm::dyn_cast<Attribute>(value)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
if (!intAttr)
|
||||
return failure();
|
||||
return intAttr.getInt();
|
||||
}
|
||||
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> evaluateIndexLike(Value value,
|
||||
const DenseMap<Value, int64_t>& bindings,
|
||||
std::optional<uint32_t> lane,
|
||||
Value laneArg) {
|
||||
if (lane && value == laneArg)
|
||||
return *lane;
|
||||
if (auto it = bindings.find(value); it != bindings.end())
|
||||
return it->second;
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
|
||||
return constant.value();
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
|
||||
return intAttr.getInt();
|
||||
|
||||
if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
|
||||
auto constant = extract.getTensor().getDefiningOp<arith::ConstantOp>();
|
||||
auto elements = constant ? dyn_cast<ElementsAttr>(constant.getValue()) : nullptr;
|
||||
auto shapedType = elements ? dyn_cast<ShapedType>(elements.getType()) : nullptr;
|
||||
if (!elements || !shapedType || shapedType.getRank() != 1 || extract.getIndices().size() != 1)
|
||||
return failure();
|
||||
|
||||
FailureOr<int64_t> index = evaluateIndexLike(extract.getIndices().front(), bindings, lane, laneArg);
|
||||
if (failed(index) || *index < 0 || *index >= static_cast<int64_t>(elements.getNumElements()))
|
||||
return failure();
|
||||
|
||||
if (auto denseInts = dyn_cast<DenseIntElementsAttr>(elements))
|
||||
return (*(denseInts.value_begin<APInt>() + *index)).getSExtValue();
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||
if (!affineApply)
|
||||
return failure();
|
||||
|
||||
AffineMap map = affineApply.getAffineMap();
|
||||
if (map.getNumResults() != 1)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t, 4> operands;
|
||||
operands.reserve(affineApply.getMapOperands().size());
|
||||
for (Value operand : affineApply.getMapOperands()) {
|
||||
FailureOr<int64_t> folded = evaluateIndexLike(operand, bindings, lane, laneArg);
|
||||
if (failed(folded))
|
||||
return failure();
|
||||
operands.push_back(*folded);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
|
||||
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
|
||||
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t, 4>>
|
||||
evaluateIndexList(ArrayRef<OpFoldResult> values,
|
||||
const DenseMap<Value, int64_t>& bindings,
|
||||
std::optional<uint32_t> lane,
|
||||
Value laneArg) {
|
||||
SmallVector<int64_t, 4> result;
|
||||
result.reserve(values.size());
|
||||
for (OpFoldResult value : values) {
|
||||
FailureOr<int64_t> folded = evaluateIndexLike(value, bindings, lane, laneArg);
|
||||
if (failed(folded))
|
||||
return failure();
|
||||
result.push_back(*folded);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static Value resolveCrossbarWeightRoot(Operation* owner, Value root) {
|
||||
if (auto arg = dyn_cast<BlockArgument>(root)) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(owner)) {
|
||||
for (auto [index, operand] : llvm::enumerate(compute.getWeights()))
|
||||
if (compute.getWeightArgument(index) == arg)
|
||||
return operand;
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(owner)) {
|
||||
for (auto [index, operand] : llvm::enumerate(batch.getWeights()))
|
||||
if (batch.getWeightArgument(index) == arg)
|
||||
return operand;
|
||||
}
|
||||
}
|
||||
|
||||
return root;
|
||||
}
|
||||
|
||||
static CrossbarWeight completeCrossbarWeight(Value root,
|
||||
SmallVector<int64_t, 4> offsets,
|
||||
SmallVector<int64_t, 4> sizes,
|
||||
SmallVector<int64_t, 4> strides) {
|
||||
CrossbarWeight weight;
|
||||
weight.root = root;
|
||||
if (auto constant = root.getDefiningOp<arith::ConstantOp>())
|
||||
weight.rootAttr = static_cast<Attribute>(constant.getValue());
|
||||
weight.offsets = std::move(offsets);
|
||||
weight.sizes = std::move(sizes);
|
||||
weight.strides = std::move(strides);
|
||||
return weight;
|
||||
}
|
||||
|
||||
static FailureOr<CrossbarWeight>
|
||||
getStaticCrossbarWeight(Operation* owner,
|
||||
Value value,
|
||||
const DenseMap<Value, int64_t>& bindings,
|
||||
std::optional<uint32_t> lane,
|
||||
Value laneArg) {
|
||||
if (auto extract = value.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
FailureOr<CrossbarWeight> sourceWeight =
|
||||
getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg);
|
||||
auto offsets = evaluateIndexList(extract.getMixedOffsets(), bindings, lane, laneArg);
|
||||
auto sizes = evaluateIndexList(extract.getMixedSizes(), bindings, lane, laneArg);
|
||||
auto strides = evaluateIndexList(extract.getMixedStrides(), bindings, lane, laneArg);
|
||||
if (failed(sourceWeight) || failed(offsets) || failed(sizes) || failed(strides))
|
||||
return failure();
|
||||
|
||||
if (sourceWeight->offsets.size() != offsets->size() || sourceWeight->sizes.size() != sizes->size()
|
||||
|| sourceWeight->strides.size() != strides->size()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
for (auto [index, offset] : llvm::enumerate(*offsets)) {
|
||||
sourceWeight->offsets[index] += offset * sourceWeight->strides[index];
|
||||
sourceWeight->sizes[index] = (*sizes)[index];
|
||||
sourceWeight->strides[index] *= (*strides)[index];
|
||||
}
|
||||
return *sourceWeight;
|
||||
}
|
||||
|
||||
Value root = resolveCrossbarWeightRoot(owner, value);
|
||||
auto type = dyn_cast<ShapedType>(root.getType());
|
||||
if (!type || !type.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t, 4> offsets(type.getRank(), 0);
|
||||
SmallVector<int64_t, 4> sizes(type.getShape().begin(), type.getShape().end());
|
||||
SmallVector<int64_t, 4> strides(type.getRank(), 1);
|
||||
return completeCrossbarWeight(root, std::move(offsets), std::move(sizes), std::move(strides));
|
||||
}
|
||||
|
||||
static void addCrossbarWeight(CrossbarUsage& usage, CrossbarWeight weight) {
|
||||
if (!containsCrossbarWeight(usage, weight))
|
||||
usage.push_back(std::move(weight));
|
||||
}
|
||||
|
||||
static void collectCrossbarWeightsFromOp(Operation* op,
|
||||
Operation* owner,
|
||||
DenseMap<Value, int64_t>& bindings,
|
||||
CrossbarUsage& usage,
|
||||
Value laneArg,
|
||||
std::optional<uint32_t> lane) {
|
||||
if (auto loop = dyn_cast<scf::ForOp>(op)) {
|
||||
auto lb = getConstantIntValue(loop.getLowerBound());
|
||||
auto ub = getConstantIntValue(loop.getUpperBound());
|
||||
auto step = getConstantIntValue(loop.getStep());
|
||||
if (!lb || !ub || !step || *step <= 0)
|
||||
return;
|
||||
|
||||
for (int64_t iv = *lb; iv < *ub; iv += *step) {
|
||||
bindings[loop.getInductionVar()] = iv;
|
||||
for (Operation& nested : loop.getBody()->without_terminator())
|
||||
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
|
||||
}
|
||||
bindings.erase(loop.getInductionVar());
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto vmm = dyn_cast<SpatVMMOp>(op)) {
|
||||
FailureOr<CrossbarWeight> weight = getStaticCrossbarWeight(owner, vmm.getWeight(), bindings, lane, laneArg);
|
||||
if (failed(weight)) {
|
||||
addCrossbarWeight(usage, getOpaqueCrossbarWeight(vmm.getWeight(), lane));
|
||||
return;
|
||||
}
|
||||
addCrossbarWeight(usage, *weight);
|
||||
return;
|
||||
}
|
||||
|
||||
for (Region& region : op->getRegions())
|
||||
for (Block& block : region)
|
||||
for (Operation& nested : block.without_terminator())
|
||||
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Cost> edgeCosts;
|
||||
for (const ComputeGraphEdge& edge : edges) {
|
||||
if (edge.source == edge.target)
|
||||
continue;
|
||||
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (const auto& [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||
aggregatedEdges.reserve(edgeCosts.size());
|
||||
for (const auto& [key, cost] : edgeCosts)
|
||||
aggregatedEdges.push_back({key.first, key.second, cost});
|
||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source < rhs.source;
|
||||
@@ -123,30 +414,75 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
||||
|
||||
} // namespace
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
||||
uint64_t countComputeBodyInstructions(Region& body) {
|
||||
uint64_t numOperations = 0;
|
||||
body.walk([&](Operation* op) { numOperations = checkedAdd(numOperations, static_cast<uint64_t>(1)); });
|
||||
return numOperations;
|
||||
}
|
||||
|
||||
uint64_t countComputeBodyOperationInstances(Region& body) {
|
||||
uint64_t instances = 0;
|
||||
for (Block& block : body)
|
||||
for (Operation& op : block)
|
||||
instances = checkedAdd(instances, countOperationInstances(op));
|
||||
return instances;
|
||||
}
|
||||
|
||||
CrossbarUsage collectDistinctCrossbarWeights(Operation* owner, std::optional<uint32_t> lane) {
|
||||
CrossbarUsage usage;
|
||||
DenseMap<Value, int64_t> bindings;
|
||||
Value laneArg;
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(owner))
|
||||
if (auto maybeLaneArg = batch.getLaneArgument())
|
||||
laneArg = *maybeLaneArg;
|
||||
|
||||
for (Region& region : owner->getRegions())
|
||||
for (Block& block : region)
|
||||
for (Operation& op : block.without_terminator())
|
||||
collectCrossbarWeightsFromOp(&op, owner, bindings, usage, laneArg, lane);
|
||||
return usage;
|
||||
}
|
||||
|
||||
Cost getComputeInstanceCost(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getComputeBodyWeight(spatCompute.getBody());
|
||||
return getComputeBodyCost(spatCompute.getBody());
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
return checkedMultiply(getComputeBodyCost(batch.getBody()), static_cast<Cost>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getSpatComputeCrossbarUsage(SpatCompute spatComute){
|
||||
CrossbarUsage ret;
|
||||
ret.insert_range(spatComute.getWeights());
|
||||
return ret;
|
||||
bool containsCrossbarWeight(ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight) {
|
||||
return llvm::is_contained(usage, weight);
|
||||
}
|
||||
|
||||
CrossbarUsage getSpatComputeBatchCrossbarUsage(SpatComputeBatch spatComuteBatch){
|
||||
CrossbarUsage ret;
|
||||
ret.insert_range(spatComuteBatch.getWeights());
|
||||
return ret;
|
||||
unsigned countCrossbarOverlap(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
|
||||
unsigned overlap = 0;
|
||||
for (const CrossbarWeight& weight : rhs)
|
||||
if (containsCrossbarWeight(lhs, weight))
|
||||
++overlap;
|
||||
return overlap;
|
||||
}
|
||||
|
||||
size_t getCrossbarUnionSize(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
|
||||
size_t size = lhs.size();
|
||||
for (const CrossbarWeight& weight : rhs)
|
||||
if (!containsCrossbarWeight(lhs, weight))
|
||||
++size;
|
||||
return size;
|
||||
}
|
||||
|
||||
void insertCrossbarWeights(CrossbarUsage& usage, ArrayRef<CrossbarWeight> weights) {
|
||||
for (const CrossbarWeight& weight : weights)
|
||||
addCrossbarWeight(usage, weight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return getSpatComputeBatchCrossbarUsage(batch);
|
||||
CrossbarUsage usage;
|
||||
if (isa<SpatCompute>(instance.op))
|
||||
return collectDistinctCrossbarWeights(instance.op);
|
||||
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
insertCrossbarWeights(usage, collectDistinctCrossbarWeights(instance.op, lane));
|
||||
return usage;
|
||||
}
|
||||
|
||||
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
@@ -161,7 +497,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
continue;
|
||||
}
|
||||
@@ -173,7 +509,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
}
|
||||
}
|
||||
@@ -185,7 +521,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||
for (Value input : inputs) {
|
||||
Weight transferCost = getInputTransferCost(node.instance, input);
|
||||
Cost transferCost = getInputTransferCost(node.instance, input);
|
||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||
@@ -208,7 +544,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||
graph.edges.insert(graph.edges.end(), aggregatedEdges.begin(), aggregatedEdges.end());
|
||||
graph.successors.assign(graph.nodes.size(), {});
|
||||
graph.predecessors.assign(graph.nodes.size(), {});
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
@@ -233,8 +569,8 @@ bool verifyAcyclic(const ComputeGraph& graph) {
|
||||
size_t node = readyNodes.front();
|
||||
readyNodes.pop();
|
||||
++visited;
|
||||
for (const auto& [child, weight] : graph.successors[node]) {
|
||||
(void) weight;
|
||||
for (const auto& [child, cost] : graph.successors[node]) {
|
||||
(void) cost;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyNodes.push(child);
|
||||
|
||||
@@ -1,52 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Utils.hpp"
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "Utils.hpp"
|
||||
|
||||
using CrossbarUsage = llvm::SmallPtrSet<mlir::Value, 6>;
|
||||
struct CrossbarWeight {
|
||||
mlir::Value root;
|
||||
mlir::Attribute rootAttr;
|
||||
llvm::SmallVector<int64_t, 4> offsets;
|
||||
llvm::SmallVector<int64_t, 4> sizes;
|
||||
llvm::SmallVector<int64_t, 4> strides;
|
||||
mlir::Value opaqueValue;
|
||||
uint32_t opaqueLane = 0;
|
||||
|
||||
bool operator==(const CrossbarWeight& other) const {
|
||||
bool sameRoot = rootAttr && other.rootAttr ? rootAttr == other.rootAttr : root == other.root;
|
||||
return sameRoot && offsets == other.offsets && sizes == other.sizes && strides == other.strides
|
||||
&& opaqueValue == other.opaqueValue && opaqueLane == other.opaqueLane;
|
||||
}
|
||||
};
|
||||
|
||||
using CrossbarUsage = llvm::SmallVector<CrossbarWeight, 6>;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeGraphNode {
|
||||
ComputeInstance instance;
|
||||
Weight weight = 0;
|
||||
llvm::SmallPtrSet<mlir::Value,6> crossbarUsage;
|
||||
Cost cost = 0;
|
||||
CrossbarUsage crossbarUsage;
|
||||
size_t originalOrder = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraphEdge {
|
||||
size_t source = 0;
|
||||
size_t target = 0;
|
||||
Weight transferCost = 0;
|
||||
Cost transferCost = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraph {
|
||||
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||
std::vector<ComputeGraphNode> nodes;
|
||||
std::vector<ComputeGraphEdge> edges;
|
||||
std::vector<std::vector<std::pair<size_t, Cost>>> successors;
|
||||
std::vector<std::vector<std::pair<size_t, Cost>>> predecessors;
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
};
|
||||
|
||||
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph& graph);
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance);
|
||||
uint64_t countComputeBodyInstructions(mlir::Region& body);
|
||||
uint64_t countComputeBodyOperationInstances(mlir::Region& body);
|
||||
Cost getComputeInstanceCost(const ComputeInstance& instance);
|
||||
CrossbarUsage collectDistinctCrossbarWeights(mlir::Operation* owner, std::optional<uint32_t> lane = std::nullopt);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
|
||||
bool containsCrossbarWeight(llvm::ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight);
|
||||
unsigned countCrossbarOverlap(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
|
||||
size_t getCrossbarUnionSize(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
|
||||
void insertCrossbarWeights(CrossbarUsage& usage, llvm::ArrayRef<CrossbarWeight> weights);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
+4
-4
@@ -48,12 +48,12 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
unsigned int usedCrossbars = 0;
|
||||
CrossbarUsage usedCrossbars;
|
||||
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||
if (scheduledTasks[slot].first != slot)
|
||||
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage.size());
|
||||
if (usedCrossbars > crossbarCapacity)
|
||||
insertCrossbarWeights(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||
if (usedCrossbars.size() > crossbarCapacity)
|
||||
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
||||
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
if (targetStart < earliestTargetStart) {
|
||||
|
||||
@@ -89,8 +89,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||
|
||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||
|
||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||
@@ -163,7 +163,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
}
|
||||
|
||||
std::vector<char> scheduled(nodeCount, false);
|
||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, llvm::SmallPtrSet<mlir::Value, 6> {});
|
||||
std::vector<CrossbarUsage> processorCrossbars(processorCount);
|
||||
std::vector<ScheduledTask> schedules(nodeCount);
|
||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||
|
||||
@@ -178,25 +178,17 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time bestEst = 0;
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
unsigned int bestOverlapWeight = 0;
|
||||
unsigned int bestOverlapCount = 0;
|
||||
bool crossbarRejected = false;
|
||||
|
||||
auto crossbarsAreContainedInProcessor = [&processorCrossbars](mlir::Value nodeCrossbar, size_t processor) {
|
||||
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
||||
};
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
auto crossbarsAreContained = std::bind(crossbarsAreContainedInProcessor, std::placeholders::_1, processor);
|
||||
if (graph.nodes[task].crossbarUsage.size() != 0
|
||||
&& !llvm::all_of(graph.nodes[task].crossbarUsage, crossbarsAreContained)
|
||||
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
|
||||
unsigned int overlapCount = countCrossbarOverlap(processorCrossbars[processor], graph.nodes[task].crossbarUsage);
|
||||
if (!graph.nodes[task].crossbarUsage.empty()
|
||||
&& getCrossbarUnionSize(processorCrossbars[processor], graph.nodes[task].crossbarUsage)
|
||||
> options.crossbarCapacity) {
|
||||
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
unsigned int overlapWeight =
|
||||
llvm::count_if(graph.nodes[task].crossbarUsage, crossbarsAreContained);
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
@@ -206,7 +198,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
}
|
||||
|
||||
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||
Time compWeight = getComputeCost(task, processor);
|
||||
Time computeCost = getComputeCost(task, processor);
|
||||
Time est = dataReady;
|
||||
Time currentEnd = 0;
|
||||
bool foundGap = false;
|
||||
@@ -215,7 +207,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||
Time gapStart = std::max(currentEnd, dataReady);
|
||||
|
||||
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||
if (addOrMax(gapStart, computeCost) <= schedTask.startTime) {
|
||||
est = gapStart;
|
||||
foundGap = true;
|
||||
break;
|
||||
@@ -226,7 +218,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
if (!foundGap)
|
||||
est = std::max(currentEnd, dataReady);
|
||||
|
||||
Time eft = addOrMax(est, compWeight);
|
||||
Time eft = addOrMax(est, computeCost);
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
|
||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
@@ -235,14 +227,14 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapWeight = overlapWeight;
|
||||
bestOverlapCount = overlapCount;
|
||||
}
|
||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapWeight < bestOverlapWeight) {
|
||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapWeight = overlapWeight;
|
||||
bestOverlapCount = overlapCount;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,7 +257,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
schedules[task] = {bestProcessor, bestEst, bestEft};
|
||||
scheduled[task] = true;
|
||||
++scheduledCount;
|
||||
processorCrossbars[bestProcessor].insert_range(graph.nodes[task].crossbarUsage);
|
||||
insertCrossbarWeights(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||
|
||||
// 3. CRITICAL FIX: Topological Append
|
||||
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||
@@ -319,37 +311,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
}
|
||||
}
|
||||
}
|
||||
/*{
|
||||
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
||||
std::vector<bool> visited(processorCount, false);
|
||||
size_t uniqueClassCount = 0;
|
||||
|
||||
for (size_t i = 0; i < processorCount; ++i) {
|
||||
if (visited[i])
|
||||
continue;
|
||||
|
||||
// We found a new unique schedule (equivalence class)
|
||||
++uniqueClassCount;
|
||||
visited[i] = true;
|
||||
|
||||
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
||||
|
||||
// Find and mark all identical companions
|
||||
auto it = equivalentClass.find(i);
|
||||
if (it != equivalentClass.end()) {
|
||||
for (size_t eqCpu : it->second) {
|
||||
if (!visited[eqCpu]) {
|
||||
llvm::dbgs() << ", " << eqCpu;
|
||||
visited[eqCpu] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
llvm::dbgs() << " }\n";
|
||||
}
|
||||
|
||||
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
||||
llvm::dbgs() << "--------------------------------------\n";
|
||||
}*/
|
||||
|
||||
// 6. Populate Final Result
|
||||
MergeScheduleResult result;
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using CPU = int;
|
||||
using Weight = unsigned long long;
|
||||
using Cost = unsigned long long;
|
||||
using Time = unsigned long long;
|
||||
|
||||
|
||||
@@ -55,4 +55,3 @@ inline T subtractOrZero(T lhs, T rhs) {
|
||||
}
|
||||
|
||||
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user