This commit is contained in:
@@ -40,6 +40,21 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
|
|||||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
|
||||||
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||||
|
for (Operation& op : *hostBlock) {
|
||||||
|
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||||
|
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||||
|
continue;
|
||||||
|
return constantOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(hostBlock);
|
||||||
|
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||||
}
|
}
|
||||||
@@ -49,6 +64,11 @@ Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, Operation
|
|||||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
|
||||||
|
Builder builder(anchorOp->getContext());
|
||||||
|
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||||
Builder builder(anchorOp->getContext());
|
Builder builder(anchorOp->getContext());
|
||||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
@@ -17,10 +14,17 @@ mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
|||||||
mlir::Type type,
|
mlir::Type type,
|
||||||
mlir::OperationFolder& folder);
|
mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||||
|
mlir::Attribute value,
|
||||||
|
mlir::Type type,
|
||||||
|
mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
@@ -12,13 +13,12 @@
|
|||||||
|
|
||||||
#include "Common/Common.hpp"
|
#include "Common/Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -118,6 +118,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||||
@@ -156,6 +157,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
target.addLegalDialect<spatial::SpatialDialect,
|
target.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addIllegalOp<ONNXMatMulOp>();
|
target.addIllegalOp<ONNXMatMulOp>();
|
||||||
@@ -189,6 +191,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
|
|
||||||
@@ -203,6 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "Common/IR/WeightUtils.hpp"
|
#include "Common/IR/WeightUtils.hpp"
|
||||||
@@ -13,17 +11,28 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
namespace {
|
||||||
for (auto extractSlice : func.getOps<tensor::ExtractSliceOp>()) {
|
|
||||||
auto source = getCompileTimeSource(extractSlice.getOperation());
|
|
||||||
if (source && hasWeightAlways(source->source) && source->chainLength > 1) {
|
|
||||||
|
|
||||||
diagnostics.report(extractSlice.getOperation(),
|
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
[](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); });
|
func.walk([&](Operation* op) {
|
||||||
}
|
if (!hasWeightAlways(op))
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
if (hasOnlySpatialMvmVmmWeightUses(result))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError(
|
||||||
|
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
||||||
|
});
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||||
pim::CappedDiagnosticReporter diagnostics;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
|
||||||
@@ -38,9 +47,7 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
|||||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
checkWeightUseChains(funcOp, diagnostics);
|
||||||
checkWeightsDirectlyExtracted(funcOp, diagnostics);
|
|
||||||
|
|
||||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||||
|
|
||||||
return success(!diagnostics.hasFailure());
|
return success(!diagnostics.hasFailure());
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
+1
-1
@@ -67,7 +67,7 @@ static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||||
std::fstream file = openReportFile("static_memory_coalescing_report");
|
std::fstream file = openReportFile("memory_coalescing_report");
|
||||||
if (!file.is_open())
|
if (!file.is_open())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ def SpatCompute : SpatOp<"compute",
|
|||||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
@@ -84,6 +85,7 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
|||||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||||
|
::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights();
|
||||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||||
}];
|
}];
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
@@ -35,6 +36,17 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
|
|||||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using CrossbarWeightSet = llvm::SetVector<Value, llvm::SmallVector<Value, 4>, llvm::SmallDenseSet<Value, 4>>;
|
||||||
|
|
||||||
|
CrossbarWeightSet collectCrossbarWeights(Region& body) {
|
||||||
|
CrossbarWeightSet weights;
|
||||||
|
body.walk([&](SpatVMMOp vmmOp) {
|
||||||
|
Value weight = vmmOp.getWeight();
|
||||||
|
weights.insert(weight);
|
||||||
|
});
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
||||||
@@ -45,7 +57,6 @@ std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
|||||||
|
|
||||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||||
llvm::dbgs() << "Disse netanyao\n";
|
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
auto index = std::distance(getWeights().begin(), existing);
|
||||||
return {
|
return {
|
||||||
{*existing, *getWeightArgument(index)}
|
{*existing, *getWeightArgument(index)}
|
||||||
@@ -75,6 +86,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigne
|
|||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
|
||||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
if (idx > getNumResults())
|
if (idx > getNumResults())
|
||||||
@@ -127,7 +140,6 @@ std::optional<std::tuple<Value, BlockArgument>>
|
|||||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||||
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||||
auto index = std::distance(getWeights().begin(), existing);
|
auto index = std::distance(getWeights().begin(), existing);
|
||||||
llvm::dbgs() << "Bum bum bum bum\n";
|
|
||||||
return {
|
return {
|
||||||
{*existing, *getWeightArgument(index)}
|
{*existing, *getWeightArgument(index)}
|
||||||
};
|
};
|
||||||
@@ -156,6 +168,8 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
|||||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); }
|
||||||
|
|
||||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||||
if (idx > getNumResults())
|
if (idx > getNumResults())
|
||||||
|
|||||||
@@ -10,6 +10,9 @@
|
|||||||
#include "mlir/IR/Types.h"
|
#include "mlir/IR/Types.h"
|
||||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|||||||
@@ -10,6 +10,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -239,6 +240,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
|||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||||
|
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||||
|
|
||||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
printer << " coreId " << coreIdAttr.getInt();
|
printer << " coreId " << coreIdAttr.getInt();
|
||||||
@@ -264,6 +266,7 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
SmallVector<Type> weightTypes;
|
SmallVector<Type> weightTypes;
|
||||||
SmallVector<Type> inputTypes;
|
SmallVector<Type> inputTypes;
|
||||||
SmallVector<Type> outputTypes;
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
int32_t coreId = 0;
|
int32_t coreId = 0;
|
||||||
|
|
||||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||||
@@ -273,9 +276,14 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
|||||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
|
||||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||||
if (hasCoreId && parser.parseInteger(coreId))
|
if (hasCoreId && parser.parseInteger(coreId))
|
||||||
return failure();
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|| parseCompressedRepeatedList(
|
|| parseCompressedRepeatedList(
|
||||||
@@ -357,6 +365,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
|||||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||||
printer << " ";
|
printer << " ";
|
||||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||||
|
printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size();
|
||||||
|
|
||||||
if (getNumResults() != 0) {
|
if (getNumResults() != 0) {
|
||||||
printer << " shared_outs";
|
printer << " shared_outs";
|
||||||
@@ -395,6 +404,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
SmallVector<Type> weightTypes;
|
SmallVector<Type> weightTypes;
|
||||||
SmallVector<Type> inputTypes;
|
SmallVector<Type> inputTypes;
|
||||||
SmallVector<Type> outputTypes;
|
SmallVector<Type> outputTypes;
|
||||||
|
int32_t crossbarWeightCount = 0;
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
|
|
||||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||||
@@ -413,9 +423,14 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
|||||||
if (parseBlockArgumentList(parser, outputArgs))
|
if (parseBlockArgumentList(parser, outputArgs))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights");
|
||||||
|
if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount))
|
||||||
|
return failure();
|
||||||
|
|
||||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||||
return failure();
|
return failure();
|
||||||
|
(void) crossbarWeightCount;
|
||||||
|
|
||||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||||
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
|| parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes)
|
||||||
|
|||||||
@@ -50,10 +50,9 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
|||||||
|
|
||||||
template <typename ComputeOpTy>
|
template <typename ComputeOpTy>
|
||||||
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
|
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
|
||||||
for (Value weight : computeOp.getWeights()) {
|
for (Value weight : computeOp.getWeights())
|
||||||
if (!isCompileTimeComputable(weight))
|
if (!isCompileTimeComputable(weight))
|
||||||
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
|
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
|
||||||
}
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,11 +130,9 @@ verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgume
|
|||||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||||
|
|
||||||
auto offsets = sliceOp.getOffsets();
|
auto offsets = sliceOp.getOffsets();
|
||||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
for (Value offset : offsets)
|
||||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
if (!isSupportedLaneOffsetExpr(offset, laneArg))
|
||||||
if (!supported)
|
|
||||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -155,11 +152,9 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle
|
|||||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||||
|
|
||||||
auto offsets = sliceOp.getOffsets();
|
auto offsets = sliceOp.getOffsets();
|
||||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
for (Value offset : offsets)
|
||||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
if (!isSupportedLaneOffsetExpr(offset, laneArg))
|
||||||
if (!supported)
|
|
||||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||||
}
|
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
#include "mlir/Analysis/TopologicalSortUtils.h"
|
#include "mlir/Analysis/TopologicalSortUtils.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
@@ -34,6 +32,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "MaterializeMergeSchedule.hpp"
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
|
#include "Scheduling/ComputeGraph.hpp"
|
||||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
#include "Scheduling/MergeSchedulingAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
|
||||||
@@ -282,11 +281,8 @@ void emitMotifProfile(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||||
ComputeMotifInfo& info = computeInfos[index];
|
ComputeMotifInfo& info = computeInfos[index];
|
||||||
for (Operation& op : compute.getBody().front()) {
|
info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody());
|
||||||
info.instructionCount++;
|
compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; });
|
||||||
if (isa<spatial::SpatVMMOp>(&op))
|
|
||||||
info.weightedVmmCount++;
|
|
||||||
}
|
|
||||||
if (info.weightedVmmCount > 0) {
|
if (info.weightedVmmCount > 0) {
|
||||||
weightedVmmNodeCount++;
|
weightedVmmNodeCount++;
|
||||||
weightedVmmOpCount += info.weightedVmmCount;
|
weightedVmmOpCount += info.weightedVmmCount;
|
||||||
@@ -480,7 +476,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
struct ReportRow {
|
struct ReportRow {
|
||||||
uint64_t id = 0;
|
uint64_t id = 0;
|
||||||
uint64_t logicalComputeCount = 0;
|
uint64_t logicalComputeCount = 0;
|
||||||
uint64_t weightCount = 0;
|
uint64_t crossbarCount = 0;
|
||||||
uint64_t instructionCount = 0;
|
uint64_t instructionCount = 0;
|
||||||
bool isRebatched = false;
|
bool isRebatched = false;
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
@@ -490,38 +486,40 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
uint64_t totalLogicalComputes = 0;
|
uint64_t totalLogicalComputes = 0;
|
||||||
uint64_t totalBatchComputeOps = 0;
|
uint64_t totalBatchComputeOps = 0;
|
||||||
uint64_t totalInstructionCount = 0;
|
uint64_t totalInstructionCount = 0;
|
||||||
uint64_t totalWeightCount = 0;
|
uint64_t totalCrossbarCount = 0;
|
||||||
uint64_t nextBatchId = 0;
|
uint64_t nextBatchId = 0;
|
||||||
std::vector<ReportRow> collectedData;
|
std::vector<ReportRow> collectedData;
|
||||||
|
|
||||||
|
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||||
|
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||||
|
};
|
||||||
|
|
||||||
for (Operation& op : funcOp.getBody().front()) {
|
for (Operation& op : funcOp.getBody().front()) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||||
uint64_t numInst = 0;
|
uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody());
|
||||||
for (auto& _ : spatCompute.getRegion().front())
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation());
|
||||||
++numInst;
|
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
if (auto coreId = getComputeCoreId(spatCompute))
|
if (auto coreId = getComputeCoreId(spatCompute))
|
||||||
coreIds.push_back(*coreId);
|
coreIds.push_back(*coreId);
|
||||||
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds});
|
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||||
totalLogicalComputes += 1;
|
totalLogicalComputes += 1;
|
||||||
totalInstructionCount += numInst;
|
totalInstructionCount += numInst;
|
||||||
totalWeightCount += spatCompute.getWeights().size();
|
totalCrossbarCount += perInstanceCrossbarCount;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||||
uint64_t numInst = 0;
|
uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody());
|
||||||
for (auto& _ : batch.getRegion().front())
|
|
||||||
++numInst;
|
|
||||||
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
uint64_t logicalCount = static_cast<uint64_t>(batch.getLaneCount());
|
||||||
|
uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation());
|
||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
|
||||||
collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
|
collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds});
|
||||||
totalComputeOps += 1;
|
totalComputeOps += 1;
|
||||||
totalLogicalComputes += logicalCount;
|
totalLogicalComputes += logicalCount;
|
||||||
totalBatchComputeOps += 1;
|
totalBatchComputeOps += 1;
|
||||||
totalInstructionCount += numInst * logicalCount;
|
totalInstructionCount += numInst * logicalCount;
|
||||||
totalWeightCount += batch.getWeights().size();
|
totalCrossbarCount += perInstanceCrossbarCount * logicalCount;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -531,7 +529,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
{"Number of logical computes", std::to_string(totalLogicalComputes) },
|
{"Number of logical computes", std::to_string(totalLogicalComputes) },
|
||||||
{"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
|
{"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
|
||||||
{"Number of instructions", std::to_string(totalInstructionCount)},
|
{"Number of instructions", std::to_string(totalInstructionCount)},
|
||||||
{"Number of used crossbars", std::to_string(totalWeightCount) }
|
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||||
};
|
};
|
||||||
printReportTotalsBlock(os, totalFields);
|
printReportTotalsBlock(os, totalFields);
|
||||||
if (!collectedData.empty())
|
if (!collectedData.empty())
|
||||||
@@ -545,7 +543,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
|
|
||||||
for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) {
|
for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) {
|
||||||
ReportRow next = collectedData[nI];
|
ReportRow next = collectedData[nI];
|
||||||
if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount
|
if (current.isRebatched == next.isRebatched && current.crossbarCount == next.crossbarCount
|
||||||
&& current.instructionCount == next.instructionCount
|
&& current.instructionCount == next.instructionCount
|
||||||
&& current.logicalComputeCount == next.logicalComputeCount)
|
&& current.logicalComputeCount == next.logicalComputeCount)
|
||||||
lastIndex = nI;
|
lastIndex = nI;
|
||||||
@@ -578,20 +576,20 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
os << ":\n";
|
os << ":\n";
|
||||||
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
|
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
|
||||||
uint64_t perCoreInstructionCount = current.instructionCount;
|
uint64_t perCoreInstructionCount = current.instructionCount;
|
||||||
uint64_t perCoreWeightCount =
|
uint64_t perCoreCrossbarCount =
|
||||||
current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount;
|
current.logicalComputeCount == 0 ? 0 : current.crossbarCount / current.logicalComputeCount;
|
||||||
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
|
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
|
||||||
|
|
||||||
llvm::SmallVector<ReportField, 3> perCoreFields = {
|
llvm::SmallVector<ReportField, 3> perCoreFields = {
|
||||||
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
|
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
|
||||||
{"Number of instructions", std::to_string(perCoreInstructionCount) },
|
{"Number of instructions", std::to_string(perCoreInstructionCount) },
|
||||||
{"Number of used crossbars", std::to_string(perCoreWeightCount) }
|
{"Number of used crossbars", std::to_string(perCoreCrossbarCount) }
|
||||||
};
|
};
|
||||||
if (current.isRebatched) {
|
if (current.isRebatched) {
|
||||||
llvm::SmallVector<ReportField, 3> totalEntryFields = {
|
llvm::SmallVector<ReportField, 3> totalEntryFields = {
|
||||||
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
|
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
|
||||||
{"Number of instructions", std::to_string(totalEntryInstructionCount) },
|
{"Number of instructions", std::to_string(totalEntryInstructionCount) },
|
||||||
{"Number of used crossbars", std::to_string(current.weightCount) }
|
{"Number of used crossbars", std::to_string(current.crossbarCount) }
|
||||||
};
|
};
|
||||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
|
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
|
||||||
}
|
}
|
||||||
@@ -655,7 +653,7 @@ public:
|
|||||||
}
|
}
|
||||||
emitMergeIrCounts("final-post-merge", func);
|
emitMergeIrCounts("final-post-merge", func);
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||||
generateReport(func, "dcp_merge_report", analysisResult->cpuToLastComputeMap.size());
|
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Unit.h"
|
#include "mlir/IR/Unit.h"
|
||||||
@@ -9,12 +13,14 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
#include <limits>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ComputeGraph.hpp"
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "ComputeInstanceUtils.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -22,15 +28,42 @@ namespace spatial {
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
uint64_t countComputeBodyInstructions(Region& body);
|
||||||
|
uint64_t countComputeBodyOperationInstances(Region& body);
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Weight getComputeBodyWeight(Region& body) {
|
Cost getComputeBodyCost(Region& body) {
|
||||||
constexpr Weight kOperationWeight = 100;
|
constexpr Cost kOperationCost = 100;
|
||||||
Weight numOperations = 0;
|
return checkedMultiply(static_cast<Cost>(countComputeBodyOperationInstances(body)), kOperationCost);
|
||||||
for (auto& block : body)
|
}
|
||||||
for ([[maybe_unused]] auto& op : block)
|
|
||||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
std::optional<uint64_t> getStaticTripCount(scf::ForOp loop) {
|
||||||
return checkedMultiply(numOperations, kOperationWeight);
|
auto lb = getConstantIntValue(loop.getLowerBound());
|
||||||
|
auto ub = getConstantIntValue(loop.getUpperBound());
|
||||||
|
auto step = getConstantIntValue(loop.getStep());
|
||||||
|
if (!lb || !ub || !step || *step <= 0)
|
||||||
|
return std::nullopt;
|
||||||
|
if (*ub <= *lb)
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
uint64_t distance = static_cast<uint64_t>(*ub - *lb);
|
||||||
|
uint64_t stride = static_cast<uint64_t>(*step);
|
||||||
|
return (distance + stride - 1) / stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t countOperationInstances(Operation& op) {
|
||||||
|
if (auto loop = dyn_cast<scf::ForOp>(&op)) {
|
||||||
|
std::optional<uint64_t> tripCount = getStaticTripCount(loop);
|
||||||
|
if (!tripCount)
|
||||||
|
return 1;
|
||||||
|
return checkedMultiply(countComputeBodyOperationInstances(loop.getRegion()), *tripCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t instances = 1;
|
||||||
|
for (Region& region : op.getRegions())
|
||||||
|
instances = checkedAdd(instances, countComputeBodyOperationInstances(region));
|
||||||
|
return instances;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isUsedAsWeightOnly(Operation* producerOp) {
|
bool isUsedAsWeightOnly(Operation* producerOp) {
|
||||||
@@ -61,7 +94,7 @@ bool isLaneOffset(OpFoldResult offset, Value laneArg) {
|
|||||||
return offsetValue == laneArg;
|
return offsetValue == laneArg;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
std::optional<Cost> getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) {
|
||||||
auto inputIt = llvm::find(batch.getInputs(), input);
|
auto inputIt = llvm::find(batch.getInputs(), input);
|
||||||
if (inputIt == batch.getInputs().end())
|
if (inputIt == batch.getInputs().end())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
@@ -72,7 +105,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
|||||||
if (!inputArg || !laneArg)
|
if (!inputArg || !laneArg)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
Weight projectedCost = 0;
|
Cost projectedCost = 0;
|
||||||
for (Operation* user : inputArg->getUsers()) {
|
for (Operation* user : inputArg->getUsers()) {
|
||||||
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
auto extract = dyn_cast<tensor::ExtractSliceOp>(user);
|
||||||
if (!extract || extract.getSource() != *inputArg)
|
if (!extract || extract.getSource() != *inputArg)
|
||||||
@@ -83,7 +116,7 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
|||||||
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
auto resultType = dyn_cast<ShapedType>(extract.getResult().getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
projectedCost = checkedAdd(projectedCost, static_cast<Weight>(getSizeInBytes(resultType)));
|
projectedCost = checkedAdd(projectedCost, static_cast<Cost>(getSizeInBytes(resultType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (projectedCost == 0)
|
if (projectedCost == 0)
|
||||||
@@ -91,28 +124,286 @@ std::optional<Weight> getBatchProjectedInputTransferCost(SpatComputeBatch batch,
|
|||||||
return projectedCost;
|
return projectedCost;
|
||||||
}
|
}
|
||||||
|
|
||||||
Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) {
|
||||||
auto inputType = cast<ShapedType>(input.getType());
|
auto inputType = cast<ShapedType>(input.getType());
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
if (auto batch = dyn_cast<SpatComputeBatch>(consumerInstance.op))
|
||||||
if (std::optional<Weight> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
if (std::optional<Cost> projectedCost = getBatchProjectedInputTransferCost(batch, input))
|
||||||
return *projectedCost;
|
return *projectedCost;
|
||||||
return static_cast<Weight>(getSizeInBytes(inputType));
|
return static_cast<Cost>(getSizeInBytes(inputType));
|
||||||
|
}
|
||||||
|
|
||||||
|
static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional<uint32_t> lane) {
|
||||||
|
CrossbarWeight weight;
|
||||||
|
weight.opaqueValue = value;
|
||||||
|
weight.opaqueLane = lane.value_or(std::numeric_limits<uint32_t>::max());
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||||
|
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||||
|
return constant.getValue();
|
||||||
|
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
|
||||||
|
unsigned position = dim.getPosition();
|
||||||
|
if (position >= dims.size())
|
||||||
|
return failure();
|
||||||
|
return dims[position];
|
||||||
|
}
|
||||||
|
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||||
|
unsigned position = symbol.getPosition();
|
||||||
|
if (position >= symbols.size())
|
||||||
|
return failure();
|
||||||
|
return symbols[position];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
|
||||||
|
if (!binary)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
|
||||||
|
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr<int64_t> {
|
||||||
|
if (divisor <= 0)
|
||||||
|
return failure();
|
||||||
|
if (value >= 0)
|
||||||
|
return value / divisor;
|
||||||
|
return -((-value + divisor - 1) / divisor);
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (binary.getKind()) {
|
||||||
|
case AffineExprKind::Add: return *lhs + *rhs;
|
||||||
|
case AffineExprKind::Mul: return *lhs * *rhs;
|
||||||
|
case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs);
|
||||||
|
case AffineExprKind::CeilDiv:
|
||||||
|
if (*rhs <= 0)
|
||||||
|
return failure();
|
||||||
|
return (*lhs + *rhs - 1) / *rhs;
|
||||||
|
case AffineExprKind::Mod: {
|
||||||
|
FailureOr<int64_t> div = floorDiv(*lhs, *rhs);
|
||||||
|
if (failed(div))
|
||||||
|
return failure();
|
||||||
|
return *lhs - *div * *rhs;
|
||||||
|
}
|
||||||
|
default: return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t>
|
||||||
|
evaluateIndexLike(Value value, const DenseMap<Value, int64_t>& bindings, std::optional<uint32_t> lane, Value laneArg);
|
||||||
|
|
||||||
|
static FailureOr<int64_t> evaluateIndexLike(OpFoldResult value,
|
||||||
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
|
std::optional<uint32_t> lane,
|
||||||
|
Value laneArg) {
|
||||||
|
if (auto attr = llvm::dyn_cast<Attribute>(value)) {
|
||||||
|
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
|
if (!intAttr)
|
||||||
|
return failure();
|
||||||
|
return intAttr.getInt();
|
||||||
|
}
|
||||||
|
return evaluateIndexLike(llvm::cast<Value>(value), bindings, lane, laneArg);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<int64_t> evaluateIndexLike(Value value,
|
||||||
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
|
std::optional<uint32_t> lane,
|
||||||
|
Value laneArg) {
|
||||||
|
if (lane && value == laneArg)
|
||||||
|
return *lane;
|
||||||
|
if (auto it = bindings.find(value); it != bindings.end())
|
||||||
|
return it->second;
|
||||||
|
|
||||||
|
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
|
||||||
|
return constant.value();
|
||||||
|
|
||||||
|
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
|
||||||
|
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()))
|
||||||
|
return intAttr.getInt();
|
||||||
|
|
||||||
|
if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) {
|
||||||
|
auto constant = extract.getTensor().getDefiningOp<arith::ConstantOp>();
|
||||||
|
auto elements = constant ? dyn_cast<ElementsAttr>(constant.getValue()) : nullptr;
|
||||||
|
auto shapedType = elements ? dyn_cast<ShapedType>(elements.getType()) : nullptr;
|
||||||
|
if (!elements || !shapedType || shapedType.getRank() != 1 || extract.getIndices().size() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
FailureOr<int64_t> index = evaluateIndexLike(extract.getIndices().front(), bindings, lane, laneArg);
|
||||||
|
if (failed(index) || *index < 0 || *index >= static_cast<int64_t>(elements.getNumElements()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (auto denseInts = dyn_cast<DenseIntElementsAttr>(elements))
|
||||||
|
return (*(denseInts.value_begin<APInt>() + *index)).getSExtValue();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto affineApply = value.getDefiningOp<affine::AffineApplyOp>();
|
||||||
|
if (!affineApply)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
AffineMap map = affineApply.getAffineMap();
|
||||||
|
if (map.getNumResults() != 1)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> operands;
|
||||||
|
operands.reserve(affineApply.getMapOperands().size());
|
||||||
|
for (Value operand : affineApply.getMapOperands()) {
|
||||||
|
FailureOr<int64_t> folded = evaluateIndexLike(operand, bindings, lane, laneArg);
|
||||||
|
if (failed(folded))
|
||||||
|
return failure();
|
||||||
|
operands.push_back(*folded);
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
|
||||||
|
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
|
||||||
|
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t, 4>>
|
||||||
|
evaluateIndexList(ArrayRef<OpFoldResult> values,
|
||||||
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
|
std::optional<uint32_t> lane,
|
||||||
|
Value laneArg) {
|
||||||
|
SmallVector<int64_t, 4> result;
|
||||||
|
result.reserve(values.size());
|
||||||
|
for (OpFoldResult value : values) {
|
||||||
|
FailureOr<int64_t> folded = evaluateIndexLike(value, bindings, lane, laneArg);
|
||||||
|
if (failed(folded))
|
||||||
|
return failure();
|
||||||
|
result.push_back(*folded);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value resolveCrossbarWeightRoot(Operation* owner, Value root) {
|
||||||
|
if (auto arg = dyn_cast<BlockArgument>(root)) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(owner)) {
|
||||||
|
for (auto [index, operand] : llvm::enumerate(compute.getWeights()))
|
||||||
|
if (compute.getWeightArgument(index) == arg)
|
||||||
|
return operand;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(owner)) {
|
||||||
|
for (auto [index, operand] : llvm::enumerate(batch.getWeights()))
|
||||||
|
if (batch.getWeightArgument(index) == arg)
|
||||||
|
return operand;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
static CrossbarWeight completeCrossbarWeight(Value root,
|
||||||
|
SmallVector<int64_t, 4> offsets,
|
||||||
|
SmallVector<int64_t, 4> sizes,
|
||||||
|
SmallVector<int64_t, 4> strides) {
|
||||||
|
CrossbarWeight weight;
|
||||||
|
weight.root = root;
|
||||||
|
if (auto constant = root.getDefiningOp<arith::ConstantOp>())
|
||||||
|
weight.rootAttr = static_cast<Attribute>(constant.getValue());
|
||||||
|
weight.offsets = std::move(offsets);
|
||||||
|
weight.sizes = std::move(sizes);
|
||||||
|
weight.strides = std::move(strides);
|
||||||
|
return weight;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<CrossbarWeight>
|
||||||
|
getStaticCrossbarWeight(Operation* owner,
|
||||||
|
Value value,
|
||||||
|
const DenseMap<Value, int64_t>& bindings,
|
||||||
|
std::optional<uint32_t> lane,
|
||||||
|
Value laneArg) {
|
||||||
|
if (auto extract = value.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||||
|
FailureOr<CrossbarWeight> sourceWeight =
|
||||||
|
getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg);
|
||||||
|
auto offsets = evaluateIndexList(extract.getMixedOffsets(), bindings, lane, laneArg);
|
||||||
|
auto sizes = evaluateIndexList(extract.getMixedSizes(), bindings, lane, laneArg);
|
||||||
|
auto strides = evaluateIndexList(extract.getMixedStrides(), bindings, lane, laneArg);
|
||||||
|
if (failed(sourceWeight) || failed(offsets) || failed(sizes) || failed(strides))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (sourceWeight->offsets.size() != offsets->size() || sourceWeight->sizes.size() != sizes->size()
|
||||||
|
|| sourceWeight->strides.size() != strides->size()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [index, offset] : llvm::enumerate(*offsets)) {
|
||||||
|
sourceWeight->offsets[index] += offset * sourceWeight->strides[index];
|
||||||
|
sourceWeight->sizes[index] = (*sizes)[index];
|
||||||
|
sourceWeight->strides[index] *= (*strides)[index];
|
||||||
|
}
|
||||||
|
return *sourceWeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value root = resolveCrossbarWeightRoot(owner, value);
|
||||||
|
auto type = dyn_cast<ShapedType>(root.getType());
|
||||||
|
if (!type || !type.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> offsets(type.getRank(), 0);
|
||||||
|
SmallVector<int64_t, 4> sizes(type.getShape().begin(), type.getShape().end());
|
||||||
|
SmallVector<int64_t, 4> strides(type.getRank(), 1);
|
||||||
|
return completeCrossbarWeight(root, std::move(offsets), std::move(sizes), std::move(strides));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void addCrossbarWeight(CrossbarUsage& usage, CrossbarWeight weight) {
|
||||||
|
if (!containsCrossbarWeight(usage, weight))
|
||||||
|
usage.push_back(std::move(weight));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void collectCrossbarWeightsFromOp(Operation* op,
|
||||||
|
Operation* owner,
|
||||||
|
DenseMap<Value, int64_t>& bindings,
|
||||||
|
CrossbarUsage& usage,
|
||||||
|
Value laneArg,
|
||||||
|
std::optional<uint32_t> lane) {
|
||||||
|
if (auto loop = dyn_cast<scf::ForOp>(op)) {
|
||||||
|
auto lb = getConstantIntValue(loop.getLowerBound());
|
||||||
|
auto ub = getConstantIntValue(loop.getUpperBound());
|
||||||
|
auto step = getConstantIntValue(loop.getStep());
|
||||||
|
if (!lb || !ub || !step || *step <= 0)
|
||||||
|
return;
|
||||||
|
|
||||||
|
for (int64_t iv = *lb; iv < *ub; iv += *step) {
|
||||||
|
bindings[loop.getInductionVar()] = iv;
|
||||||
|
for (Operation& nested : loop.getBody()->without_terminator())
|
||||||
|
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
|
||||||
|
}
|
||||||
|
bindings.erase(loop.getInductionVar());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto vmm = dyn_cast<SpatVMMOp>(op)) {
|
||||||
|
FailureOr<CrossbarWeight> weight = getStaticCrossbarWeight(owner, vmm.getWeight(), bindings, lane, laneArg);
|
||||||
|
if (failed(weight)) {
|
||||||
|
addCrossbarWeight(usage, getOpaqueCrossbarWeight(vmm.getWeight(), lane));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addCrossbarWeight(usage, *weight);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Region& region : op->getRegions())
|
||||||
|
for (Block& block : region)
|
||||||
|
for (Operation& nested : block.without_terminator())
|
||||||
|
collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
llvm::DenseMap<std::pair<size_t, size_t>, Cost> edgeCosts;
|
||||||
for (const ComputeGraphEdge& edge : edges) {
|
for (const ComputeGraphEdge& edge : edges) {
|
||||||
if (edge.source == edge.target)
|
if (edge.source == edge.target)
|
||||||
continue;
|
continue;
|
||||||
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||||
if (!inserted.second)
|
if (!inserted.second)
|
||||||
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||||
aggregatedEdges.reserve(edgeWeights.size());
|
aggregatedEdges.reserve(edgeCosts.size());
|
||||||
for (const auto& [key, weight] : edgeWeights)
|
for (const auto& [key, cost] : edgeCosts)
|
||||||
aggregatedEdges.push_back({key.first, key.second, weight});
|
aggregatedEdges.push_back({key.first, key.second, cost});
|
||||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
|
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
|
||||||
if (lhs.source != rhs.source)
|
if (lhs.source != rhs.source)
|
||||||
return lhs.source < rhs.source;
|
return lhs.source < rhs.source;
|
||||||
@@ -123,30 +414,75 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
uint64_t countComputeBodyInstructions(Region& body) {
|
||||||
|
uint64_t numOperations = 0;
|
||||||
|
body.walk([&](Operation* op) { numOperations = checkedAdd(numOperations, static_cast<uint64_t>(1)); });
|
||||||
|
return numOperations;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t countComputeBodyOperationInstances(Region& body) {
|
||||||
|
uint64_t instances = 0;
|
||||||
|
for (Block& block : body)
|
||||||
|
for (Operation& op : block)
|
||||||
|
instances = checkedAdd(instances, countOperationInstances(op));
|
||||||
|
return instances;
|
||||||
|
}
|
||||||
|
|
||||||
|
CrossbarUsage collectDistinctCrossbarWeights(Operation* owner, std::optional<uint32_t> lane) {
|
||||||
|
CrossbarUsage usage;
|
||||||
|
DenseMap<Value, int64_t> bindings;
|
||||||
|
Value laneArg;
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(owner))
|
||||||
|
if (auto maybeLaneArg = batch.getLaneArgument())
|
||||||
|
laneArg = *maybeLaneArg;
|
||||||
|
|
||||||
|
for (Region& region : owner->getRegions())
|
||||||
|
for (Block& block : region)
|
||||||
|
for (Operation& op : block.without_terminator())
|
||||||
|
collectCrossbarWeightsFromOp(&op, owner, bindings, usage, laneArg, lane);
|
||||||
|
return usage;
|
||||||
|
}
|
||||||
|
|
||||||
|
Cost getComputeInstanceCost(const ComputeInstance& instance) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
return getComputeBodyWeight(spatCompute.getBody());
|
return getComputeBodyCost(spatCompute.getBody());
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
return checkedMultiply(getComputeBodyCost(batch.getBody()), static_cast<Cost>(instance.laneCount));
|
||||||
}
|
}
|
||||||
|
|
||||||
CrossbarUsage getSpatComputeCrossbarUsage(SpatCompute spatComute){
|
bool containsCrossbarWeight(ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight) {
|
||||||
CrossbarUsage ret;
|
return llvm::is_contained(usage, weight);
|
||||||
ret.insert_range(spatComute.getWeights());
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CrossbarUsage getSpatComputeBatchCrossbarUsage(SpatComputeBatch spatComuteBatch){
|
unsigned countCrossbarOverlap(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
|
||||||
CrossbarUsage ret;
|
unsigned overlap = 0;
|
||||||
ret.insert_range(spatComuteBatch.getWeights());
|
for (const CrossbarWeight& weight : rhs)
|
||||||
return ret;
|
if (containsCrossbarWeight(lhs, weight))
|
||||||
|
++overlap;
|
||||||
|
return overlap;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getCrossbarUnionSize(ArrayRef<CrossbarWeight> lhs, ArrayRef<CrossbarWeight> rhs) {
|
||||||
|
size_t size = lhs.size();
|
||||||
|
for (const CrossbarWeight& weight : rhs)
|
||||||
|
if (!containsCrossbarWeight(lhs, weight))
|
||||||
|
++size;
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertCrossbarWeights(CrossbarUsage& usage, ArrayRef<CrossbarWeight> weights) {
|
||||||
|
for (const CrossbarWeight& weight : weights)
|
||||||
|
addCrossbarWeight(usage, weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
CrossbarUsage usage;
|
||||||
return getSpatComputeCrossbarUsage(spatCompute);
|
if (isa<SpatCompute>(instance.op))
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
return collectDistinctCrossbarWeights(instance.op);
|
||||||
return getSpatComputeBatchCrossbarUsage(batch);
|
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
insertCrossbarWeights(usage, collectDistinctCrossbarWeights(instance.op, lane));
|
||||||
|
return usage;
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
ComputeGraph buildComputeGraph(Operation* entryOp) {
|
||||||
@@ -161,7 +497,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||||
size_t index = graph.nodes.size();
|
size_t index = graph.nodes.size();
|
||||||
graph.nodes.push_back(
|
graph.nodes.push_back(
|
||||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
graph.instanceToIndex[instance] = index;
|
graph.instanceToIndex[instance] = index;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -173,7 +509,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||||
size_t index = graph.nodes.size();
|
size_t index = graph.nodes.size();
|
||||||
graph.nodes.push_back(
|
graph.nodes.push_back(
|
||||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
{instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
graph.instanceToIndex[instance] = index;
|
graph.instanceToIndex[instance] = index;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -185,7 +521,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||||
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
llvm::SmallVector<Value, 4> inputs = getComputeInstanceInputs(node.instance);
|
||||||
for (Value input : inputs) {
|
for (Value input : inputs) {
|
||||||
Weight transferCost = getInputTransferCost(node.instance, input);
|
Cost transferCost = getInputTransferCost(node.instance, input);
|
||||||
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
|
||||||
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
|
||||||
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
|
||||||
@@ -208,7 +544,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||||
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
graph.edges.insert(graph.edges.end(), aggregatedEdges.begin(), aggregatedEdges.end());
|
||||||
graph.successors.assign(graph.nodes.size(), {});
|
graph.successors.assign(graph.nodes.size(), {});
|
||||||
graph.predecessors.assign(graph.nodes.size(), {});
|
graph.predecessors.assign(graph.nodes.size(), {});
|
||||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||||
@@ -233,8 +569,8 @@ bool verifyAcyclic(const ComputeGraph& graph) {
|
|||||||
size_t node = readyNodes.front();
|
size_t node = readyNodes.front();
|
||||||
readyNodes.pop();
|
readyNodes.pop();
|
||||||
++visited;
|
++visited;
|
||||||
for (const auto& [child, weight] : graph.successors[node]) {
|
for (const auto& [child, cost] : graph.successors[node]) {
|
||||||
(void) weight;
|
(void) cost;
|
||||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
if (--remainingParents[child] == 0)
|
if (--remainingParents[child] == 0)
|
||||||
readyNodes.push(child);
|
readyNodes.push(child);
|
||||||
|
|||||||
@@ -1,52 +1,72 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallSet.h"
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <cstddef>
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Utils.hpp"
|
|
||||||
#include "ComputeInstance.hpp"
|
#include "ComputeInstance.hpp"
|
||||||
#include "ComputeInstanceUtils.hpp"
|
#include "Utils.hpp"
|
||||||
|
|
||||||
using CrossbarUsage = llvm::SmallPtrSet<mlir::Value, 6>;
|
struct CrossbarWeight {
|
||||||
|
mlir::Value root;
|
||||||
|
mlir::Attribute rootAttr;
|
||||||
|
llvm::SmallVector<int64_t, 4> offsets;
|
||||||
|
llvm::SmallVector<int64_t, 4> sizes;
|
||||||
|
llvm::SmallVector<int64_t, 4> strides;
|
||||||
|
mlir::Value opaqueValue;
|
||||||
|
uint32_t opaqueLane = 0;
|
||||||
|
|
||||||
|
bool operator==(const CrossbarWeight& other) const {
|
||||||
|
bool sameRoot = rootAttr && other.rootAttr ? rootAttr == other.rootAttr : root == other.root;
|
||||||
|
return sameRoot && offsets == other.offsets && sizes == other.sizes && strides == other.strides
|
||||||
|
&& opaqueValue == other.opaqueValue && opaqueLane == other.opaqueLane;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using CrossbarUsage = llvm::SmallVector<CrossbarWeight, 6>;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
struct ComputeGraphNode {
|
struct ComputeGraphNode {
|
||||||
ComputeInstance instance;
|
ComputeInstance instance;
|
||||||
Weight weight = 0;
|
Cost cost = 0;
|
||||||
llvm::SmallPtrSet<mlir::Value,6> crossbarUsage;
|
CrossbarUsage crossbarUsage;
|
||||||
size_t originalOrder = 0;
|
size_t originalOrder = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ComputeGraphEdge {
|
struct ComputeGraphEdge {
|
||||||
size_t source = 0;
|
size_t source = 0;
|
||||||
size_t target = 0;
|
size_t target = 0;
|
||||||
Weight transferCost = 0;
|
Cost transferCost = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ComputeGraph {
|
struct ComputeGraph {
|
||||||
llvm::SmallVector<ComputeGraphNode> nodes;
|
std::vector<ComputeGraphNode> nodes;
|
||||||
llvm::SmallVector<ComputeGraphEdge> edges;
|
std::vector<ComputeGraphEdge> edges;
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
std::vector<std::vector<std::pair<size_t, Cost>>> successors;
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
std::vector<std::vector<std::pair<size_t, Cost>>> predecessors;
|
||||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||||
};
|
};
|
||||||
|
|
||||||
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
|
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
|
||||||
bool verifyAcyclic(const ComputeGraph& graph);
|
bool verifyAcyclic(const ComputeGraph& graph);
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance& instance);
|
uint64_t countComputeBodyInstructions(mlir::Region& body);
|
||||||
|
uint64_t countComputeBodyOperationInstances(mlir::Region& body);
|
||||||
|
Cost getComputeInstanceCost(const ComputeInstance& instance);
|
||||||
|
CrossbarUsage collectDistinctCrossbarWeights(mlir::Operation* owner, std::optional<uint32_t> lane = std::nullopt);
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
|
||||||
|
bool containsCrossbarWeight(llvm::ArrayRef<CrossbarWeight> usage, const CrossbarWeight& weight);
|
||||||
|
unsigned countCrossbarOverlap(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
|
||||||
|
size_t getCrossbarUnionSize(llvm::ArrayRef<CrossbarWeight> lhs, llvm::ArrayRef<CrossbarWeight> rhs);
|
||||||
|
void insertCrossbarWeights(CrossbarUsage& usage, llvm::ArrayRef<CrossbarWeight> weights);
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
+4
-4
@@ -48,12 +48,12 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
|||||||
return lhs.second < rhs.second;
|
return lhs.second < rhs.second;
|
||||||
});
|
});
|
||||||
|
|
||||||
unsigned int usedCrossbars = 0;
|
CrossbarUsage usedCrossbars;
|
||||||
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||||
if (scheduledTasks[slot].first != slot)
|
if (scheduledTasks[slot].first != slot)
|
||||||
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||||
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage.size());
|
insertCrossbarWeights(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||||
if (usedCrossbars > crossbarCapacity)
|
if (usedCrossbars.size() > crossbarCapacity)
|
||||||
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,7 +77,7 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
|||||||
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||||
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||||
|
|
||||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||||
if (sourceCpu != targetCpu)
|
if (sourceCpu != targetCpu)
|
||||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||||
if (targetStart < earliestTargetStart) {
|
if (targetStart < earliestTargetStart) {
|
||||||
|
|||||||
@@ -89,8 +89,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
|
|
||||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||||
|
|
||||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
@@ -163,7 +163,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<char> scheduled(nodeCount, false);
|
std::vector<char> scheduled(nodeCount, false);
|
||||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, llvm::SmallPtrSet<mlir::Value, 6> {});
|
std::vector<CrossbarUsage> processorCrossbars(processorCount);
|
||||||
std::vector<ScheduledTask> schedules(nodeCount);
|
std::vector<ScheduledTask> schedules(nodeCount);
|
||||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||||
|
|
||||||
@@ -178,25 +178,17 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time bestEst = 0;
|
Time bestEst = 0;
|
||||||
Time bestEft = 0;
|
Time bestEft = 0;
|
||||||
Time bestOeft = std::numeric_limits<Time>::max();
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
unsigned int bestOverlapWeight = 0;
|
unsigned int bestOverlapCount = 0;
|
||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
auto crossbarsAreContainedInProcessor = [&processorCrossbars](mlir::Value nodeCrossbar, size_t processor) {
|
|
||||||
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
|
||||||
};
|
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
auto crossbarsAreContained = std::bind(crossbarsAreContainedInProcessor, std::placeholders::_1, processor);
|
unsigned int overlapCount = countCrossbarOverlap(processorCrossbars[processor], graph.nodes[task].crossbarUsage);
|
||||||
if (graph.nodes[task].crossbarUsage.size() != 0
|
if (!graph.nodes[task].crossbarUsage.empty()
|
||||||
&& !llvm::all_of(graph.nodes[task].crossbarUsage, crossbarsAreContained)
|
&& getCrossbarUnionSize(processorCrossbars[processor], graph.nodes[task].crossbarUsage)
|
||||||
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
|
|
||||||
> options.crossbarCapacity) {
|
> options.crossbarCapacity) {
|
||||||
|
|
||||||
crossbarRejected = true;
|
crossbarRejected = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
unsigned int overlapWeight =
|
|
||||||
llvm::count_if(graph.nodes[task].crossbarUsage, crossbarsAreContained);
|
|
||||||
|
|
||||||
Time dataReady = 0;
|
Time dataReady = 0;
|
||||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
@@ -206,7 +198,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||||
Time compWeight = getComputeCost(task, processor);
|
Time computeCost = getComputeCost(task, processor);
|
||||||
Time est = dataReady;
|
Time est = dataReady;
|
||||||
Time currentEnd = 0;
|
Time currentEnd = 0;
|
||||||
bool foundGap = false;
|
bool foundGap = false;
|
||||||
@@ -215,7 +207,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||||
Time gapStart = std::max(currentEnd, dataReady);
|
Time gapStart = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
if (addOrMax(gapStart, computeCost) <= schedTask.startTime) {
|
||||||
est = gapStart;
|
est = gapStart;
|
||||||
foundGap = true;
|
foundGap = true;
|
||||||
break;
|
break;
|
||||||
@@ -226,7 +218,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
if (!foundGap)
|
if (!foundGap)
|
||||||
est = std::max(currentEnd, dataReady);
|
est = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
Time eft = addOrMax(est, compWeight);
|
Time eft = addOrMax(est, computeCost);
|
||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
@@ -235,14 +227,14 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapWeight = overlapWeight;
|
bestOverlapCount = overlapCount;
|
||||||
}
|
}
|
||||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapWeight < bestOverlapWeight) {
|
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapWeight = overlapWeight;
|
bestOverlapCount = overlapCount;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,7 +257,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
schedules[task] = {bestProcessor, bestEst, bestEft};
|
schedules[task] = {bestProcessor, bestEst, bestEft};
|
||||||
scheduled[task] = true;
|
scheduled[task] = true;
|
||||||
++scheduledCount;
|
++scheduledCount;
|
||||||
processorCrossbars[bestProcessor].insert_range(graph.nodes[task].crossbarUsage);
|
insertCrossbarWeights(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||||
|
|
||||||
// 3. CRITICAL FIX: Topological Append
|
// 3. CRITICAL FIX: Topological Append
|
||||||
// Because the readyQueue pops in strict topological order, simply pushing to the
|
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||||
@@ -319,37 +311,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/*{
|
|
||||||
llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n";
|
|
||||||
std::vector<bool> visited(processorCount, false);
|
|
||||||
size_t uniqueClassCount = 0;
|
|
||||||
|
|
||||||
for (size_t i = 0; i < processorCount; ++i) {
|
|
||||||
if (visited[i])
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// We found a new unique schedule (equivalence class)
|
|
||||||
++uniqueClassCount;
|
|
||||||
visited[i] = true;
|
|
||||||
|
|
||||||
llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i;
|
|
||||||
|
|
||||||
// Find and mark all identical companions
|
|
||||||
auto it = equivalentClass.find(i);
|
|
||||||
if (it != equivalentClass.end()) {
|
|
||||||
for (size_t eqCpu : it->second) {
|
|
||||||
if (!visited[eqCpu]) {
|
|
||||||
llvm::dbgs() << ", " << eqCpu;
|
|
||||||
visited[eqCpu] = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
llvm::dbgs() << " }\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n";
|
|
||||||
llvm::dbgs() << "--------------------------------------\n";
|
|
||||||
}*/
|
|
||||||
|
|
||||||
// 6. Populate Final Result
|
// 6. Populate Final Result
|
||||||
MergeScheduleResult result;
|
MergeScheduleResult result;
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using CPU = int;
|
using CPU = int;
|
||||||
using Weight = unsigned long long;
|
using Cost = unsigned long long;
|
||||||
using Time = unsigned long long;
|
using Time = unsigned long long;
|
||||||
|
|
||||||
|
|
||||||
@@ -55,4 +55,3 @@ inline T subtractOrZero(T lhs, T rhs) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user