automatic code reformat

This commit is contained in:
NiccoloN
2026-05-22 15:23:48 +02:00
parent d136136d22
commit 8337a11ce9
37 changed files with 312 additions and 354 deletions
+3 -3
View File
@@ -13,7 +13,8 @@
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
+9 -9
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
@@ -15,13 +15,13 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType>
pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
vmmOutputs.push_back(spatial::SpatVMMOp::create(
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
gemmLoc,
currOutHSliceType,
computeOp.getWeightArgument(aHSliceId),
computeOp.getInputArgument(aHSliceId)));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
@@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
unitStrides);
tensor::ParallelInsertSliceOp::create(
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults());
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
static Value
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
static Value
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
return rewriter.notifyMatchFailure(resizeOp,
"resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
mapper.map(oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
mapper.map(compute.getOutputArgument(resultIndex),
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);
@@ -1,8 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -141,8 +141,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
: scaledOffset;
totalOffset =
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
}
if (!totalOffset)
@@ -30,8 +30,8 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner);
@@ -233,10 +233,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
auto coreOp = PimCoreOp::create(
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
@@ -217,8 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
patterns.getContext());
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -546,11 +546,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
return ReturnPathLoweringResult::NotReturnPath;
}
raptor::SpatialToPimPass::ReturnPathLoweringResult
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
OpResult result,
Value yieldValue,
IRRewriter& rewriter) {
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
}
@@ -318,7 +318,7 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
}
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
IRRewriter& rewriter) {
IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext());
@@ -39,13 +39,10 @@ private:
size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
mlir::IRRewriter& rewriter);
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
mlir::IRRewriter& rewriter,
mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
mlir::IRRewriter& rewriter);
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult {
Handled,
+2 -2
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include <string>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
+4 -5
View File
@@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<O
return parser.parseRParen();
}
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
@@ -82,10 +83,8 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
case ListDelimiter::Paren: return parser.parseRParen();
case ListDelimiter::Square: return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
+7 -10
View File
@@ -1,7 +1,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
@@ -66,8 +66,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|| isExplicitHostOperand(op, operand.getOperandNumber()))
continue;
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
@@ -153,10 +153,9 @@ LogicalResult PimCoreOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
@@ -169,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() {
return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
@@ -50,10 +50,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value)
pendingValues.push_back(forOp.getResult(index));
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
+3 -5
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
@@ -65,9 +65,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
return getRegion().front().getOperations();
}
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
void SpatialDialect::initialize() {
addTypes<
+3 -6
View File
@@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
case ListDelimiter::Paren: return parser.parseRParen();
case ListDelimiter::Square: return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
}
return success();
}
+6 -9
View File
@@ -272,9 +272,9 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
<< " is used by " << op->getName();
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
@@ -457,14 +457,12 @@ LogicalResult SpatCompute::verify() {
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -582,10 +580,9 @@ LogicalResult SpatComputeBatch::verify() {
if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType())
@@ -1,6 +1,6 @@
#include "DCPAnalysis.hpp"
#include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp"
#include "DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
@@ -11,15 +11,15 @@ using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis {
private:
DCPAnalysisResult result;
mlir::Operation *entryOp;
mlir::Operation* entryOp;
DCPAnalysisResult run();
public:
DCPAnalysis(mlir::Operation *op)
DCPAnalysis(mlir::Operation* op)
: entryOp(op) {
result = run();
}
DCPAnalysisResult &getResult() { return result; }
DCPAnalysisResult& getResult() { return result; }
};
} // namespace spatial
@@ -1,5 +1,3 @@
#include "MaterializeMergeSchedule.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -24,6 +22,7 @@
#include <utility>
#include <vector>
#include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -72,9 +71,7 @@ struct CpuSlotKey {
};
struct CpuSlotKeyInfo {
static CpuSlotKey getEmptyKey() {
return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()};
}
static CpuSlotKey getEmptyKey() { return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()}; }
static CpuSlotKey getTombstoneKey() {
return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()};
@@ -139,11 +136,12 @@ struct MaterializerState {
DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps;
MaterializerState(func::FuncOp func,
const MergeScheduleResult& schedule,
int64_t& nextChannelId)
: func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()),
nextChannelId(nextChannelId) {}
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
: func(func),
schedule(schedule),
rewriter(func.getContext()),
constantFolder(func.getContext()),
nextChannelId(nextChannelId) {}
};
bool isConstantLike(Value value) {
@@ -189,11 +187,12 @@ std::optional<uint32_t> getConstantFirstSliceOffset(tensor::ExtractSliceOp extra
return std::nullopt;
}
ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch,
uint32_t laneStart,
uint32_t laneCount,
size_t resultIndex) {
return {{batch.getOperation(), laneStart, laneCount}, resultIndex};
ProducerKey
getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) {
return {
{batch.getOperation(), laneStart, laneCount},
resultIndex
};
}
ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) {
@@ -202,8 +201,8 @@ ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex)
bool isWholeBatchProducerKey(ProducerKey key) {
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 &&
key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0
&& key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
}
SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) {
@@ -306,7 +305,10 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
auto result = dyn_cast<OpResult>(value);
if (!result)
return std::nullopt;
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
return ProducerKey {
{compute.getOperation(), 0, 1},
result.getResultNumber()
};
}
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
@@ -316,10 +318,8 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
if (batch.getNumResults() != 0) {
if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op))
return getBatchLaneProducerKey(batch,
consumerInstance->laneStart,
consumerInstance->laneCount,
result.getResultNumber());
return getBatchLaneProducerKey(
batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber());
return getWholeBatchProducerKey(batch, result.getResultNumber());
}
@@ -489,12 +489,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
continue;
}
auto batch = SpatComputeBatch::create(state.rewriter,
loc,
TypeRange(resultTypes),
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
ValueRange {},
ValueRange {});
auto batch =
SpatComputeBatch::create(state.rewriter,
loc,
TypeRange(resultTypes),
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
ValueRange {},
ValueRange {});
batch.getProperties().setOperandSegmentSizes({0, 0});
SmallVector<int32_t> coreIds;
coreIds.reserve(materializedClass.cpus.size());
@@ -506,15 +507,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
SmallVector<Location, 4> blockArgLocs {loc};
llvm::append_range(blockArgTypes, resultTypes);
blockArgLocs.append(resultTypes.size(), loc);
Block* body = state.rewriter.createBlock(
&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
Block* body =
state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
state.rewriter.setInsertionPointToEnd(body);
if (resultTypes.empty()) {
if (resultTypes.empty())
SpatYieldOp::create(state.rewriter, loc, ValueRange {});
}
else {
else
SpatInParallelOp::create(state.rewriter, loc);
}
materializedClass.op = batch.getOperation();
materializedClass.body = body;
state.rewriter.setInsertionPointAfter(batch.getOperation());
@@ -559,7 +558,8 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
else {
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
BlockArgument arg = materializedClass.body->insertArgument(materializedClass.body->getNumArguments()-1, input.getType(), input.getLoc());
BlockArgument arg = materializedClass.body->insertArgument(
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
materializedClass.inputArgs[input] = arg;
return arg;
}
@@ -586,9 +586,8 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
}
FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& state,
const MaterializedClass& materializedClass,
SlotId slot) {
FailureOr<SmallVector<ComputeInstance, 8>>
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
SmallVector<ComputeInstance, 8> peers;
peers.reserve(materializedClass.cpus.size());
for (CpuId cpu : materializedClass.cpus) {
@@ -601,9 +600,9 @@ FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& s
}
Value createOriginalLaneValue(MaterializerState& state,
MaterializedClass& materializedClass,
ArrayRef<ComputeInstance> peers,
Location loc) {
MaterializedClass& materializedClass,
ArrayRef<ComputeInstance> peers,
Location loc) {
assert(!peers.empty() && "expected at least one peer instance");
if (!materializedClass.isBatch)
return createIndexConstant(state, materializedClass.op, peers.front().laneStart);
@@ -751,12 +750,12 @@ SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, Pr
}
Value appendReceive(MaterializerState& state,
MaterializedClass& targetClass,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
MaterializedClass& targetClass,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
@@ -774,22 +773,18 @@ Value appendReceive(MaterializerState& state,
.getOutput();
}
return SpatChannelReceiveOp::create(state.rewriter,
loc,
type,
channelIdValues.front(),
sourceCoreIdValues.front(),
targetCoreIdValues.front())
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
.getOutput();
}
Value appendHostReceive(MaterializerState& state,
MaterializedClass& sourceClass,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
MaterializedClass& sourceClass,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPointAfter(sourceClass.op);
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
@@ -802,19 +797,13 @@ Value appendHostReceive(MaterializerState& state,
}
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
return SpatChannelReceiveOp::create(state.rewriter,
loc,
type,
channelIdValues.front(),
sourceCoreIdValues.front(),
targetCoreIdValues.front())
return SpatChannelReceiveOp::create(
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
.getOutput();
}
LogicalResult setHostOutputValue(MaterializerState& state,
MaterializedClass& sourceClass,
Value originalOutput,
Value payload) {
LogicalResult
setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
if (resultIt == sourceClass.hostOutputToResultIndex.end())
return sourceClass.op->emitError("missing host result slot for materialized output");
@@ -864,12 +853,12 @@ LogicalResult setHostOutputValue(MaterializerState& state,
}
void appendScalarSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
MaterializedClass& sourceClass,
Value payload,
int64_t channelId,
int32_t sourceCoreId,
int32_t targetCoreId,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
@@ -878,12 +867,12 @@ void appendScalarSend(MaterializerState& state,
}
void appendBatchSend(MaterializerState& state,
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
MaterializedClass& sourceClass,
Value payload,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
Location loc) {
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
@@ -892,11 +881,11 @@ void appendBatchSend(MaterializerState& state,
}
LogicalResult emitClassToClassCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
MaterializedClass& sourceClass,
MaterializedClass& targetClass,
ArrayRef<ProducerKey> keys,
Value payload,
Location loc) {
if (sourceClass.id == targetClass.id) {
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = payload;
@@ -908,8 +897,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef<int64_t>(channelId),
ArrayRef<int32_t>(sourceCpu), ArrayRef<int32_t>(targetCpu), loc);
Value received = appendReceive(state,
targetClass,
payload.getType(),
ArrayRef<int64_t>(channelId),
ArrayRef<int32_t>(sourceCpu),
ArrayRef<int32_t>(targetCpu),
loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
@@ -937,7 +931,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
loc);
}
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
@@ -946,8 +941,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey)
return sourceClass.op->emitError(
"cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order");
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
"lanes are not contiguous in send order");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType))
@@ -975,7 +970,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes");
return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds;
@@ -991,7 +987,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received;
return success();
@@ -1001,11 +998,11 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
}
LogicalResult emitHostCommunication(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
return success();
@@ -1025,23 +1022,25 @@ LogicalResult emitHostCommunication(MaterializerState& state,
}
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
Value received =
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
state.hostReplacements[originalOutput] = received;
return success();
}
LogicalResult emitOutputFanout(MaterializerState& state,
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
MaterializedClass& sourceClass,
ArrayRef<ProducerKey> keys,
Value payload,
Value originalOutput,
Location loc) {
if (keys.empty())
return success();
if (!sourceClass.isBatch) {
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
if (failed(
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure();
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
return failure();
@@ -1065,11 +1064,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success();
}
FailureOr<Value> materializeWholeBatchInput(MaterializerState& state,
MaterializedClass& targetClass,
ProducerKey key,
Type resultType,
Location loc) {
FailureOr<Value> materializeWholeBatchInput(
MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) {
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
if (!batch || !resultTensorType || resultTensorType.getRank() == 0)
@@ -1115,9 +1111,9 @@ FailureOr<Value> materializeWholeBatchInput(MaterializerState& state,
}
FailureOr<Value> resolveInputValue(MaterializerState& state,
MaterializedClass& targetClass,
Value input,
const ComputeInstance& consumerInstance) {
MaterializedClass& targetClass,
Value input,
const ComputeInstance& consumerInstance) {
if (isConstantLike(input))
return input;
@@ -1135,9 +1131,9 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
}
void mapWeights(MaterializerState& state,
MaterializedClass& targetClass,
const ComputeInstance& instance,
IRMapping& mapper) {
MaterializedClass& targetClass,
const ComputeInstance& instance,
IRMapping& mapper) {
Operation* op = instance.op;
if (auto compute = dyn_cast<SpatCompute>(op)) {
for (auto [index, weight] : llvm::enumerate(compute.getWeights()))
@@ -1151,9 +1147,9 @@ void mapWeights(MaterializerState& state,
}
LogicalResult mapInputs(MaterializerState& state,
MaterializedClass& targetClass,
const ComputeInstance& instance,
IRMapping& mapper) {
MaterializedClass& targetClass,
const ComputeInstance& instance,
IRMapping& mapper) {
Operation* op = instance.op;
if (auto compute = dyn_cast<SpatCompute>(op)) {
for (auto [index, input] : llvm::enumerate(compute.getInputs())) {
@@ -1200,9 +1196,8 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
return outputs;
}
FailureOr<SmallVector<Value, 4>> cloneInstanceBody(MaterializerState& state,
MaterializedClass& targetClass,
ArrayRef<ComputeInstance> peers) {
FailureOr<SmallVector<Value, 4>>
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
assert(!peers.empty() && "expected at least one peer instance");
const ComputeInstance& instance = peers.front();
Operation* sourceOp = instance.op;
@@ -1318,9 +1313,8 @@ LogicalResult eraseOldComputeOps(MaterializerState& state) {
} // namespace
LogicalResult MergeScheduleMaterializer::run(func::FuncOp func,
const MergeScheduleResult& schedule,
int64_t& nextChannelId) {
LogicalResult
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
if (schedule.dominanceOrderCompute.empty())
return success();
@@ -10,8 +10,7 @@ namespace spatial {
class MergeScheduleMaterializer {
public:
mlir::LogicalResult
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
};
} // namespace spatial
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()),
phase(phaseName.str()) {
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
MergeIrCounts counts = collectMergeIrCounts(funcOp);
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
<< " compute=" << counts.topLevelComputeCount
<< " compute_batch=" << counts.topLevelComputeBatchCount
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount
<< " wvmm=" << counts.wvmmCount
<< " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
}
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
@@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
}
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) {
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
targetWeightIndices[weight].push_back(weightIndex);
@@ -7,6 +7,6 @@
namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
} // namespace onnx_mlir
@@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) {
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
SmallVector<Value> sourceCoreIdValues =
createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues =
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
spatial::SpatChannelSendTensorOp::create(
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
for (auto op : run.ops)
@@ -39,11 +39,11 @@ struct ComputeGraph {
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
};
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
bool verifyAcyclic(const ComputeGraph &graph);
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
bool verifyAcyclic(const ComputeGraph& graph);
Weight getComputeInstanceWeight(const ComputeInstance &instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
Weight getComputeInstanceWeight(const ComputeInstance& instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -11,11 +11,11 @@ namespace onnx_mlir {
namespace spatial {
struct ComputeInstance {
mlir::Operation *op = nullptr;
mlir::Operation* op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance &other) const {
bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
@@ -29,16 +29,15 @@ namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
}
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
const onnx_mlir::spatial::ComputeInstance &rhs) {
static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
return lhs == rhs;
}
};
@@ -27,15 +27,15 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex)
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr);
const ComputeInstance* consumerInstance = nullptr);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr);
const ComputeInstance* consumerInstance = nullptr);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance& instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance& instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance);
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance);
mlir::Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -10,8 +10,8 @@
#include <queue>
#include <vector>
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp"
#include "DcpScheduler.hpp"
namespace onnx_mlir {
namespace spatial {
@@ -47,7 +47,7 @@ struct WindowScheduleResult {
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
size_t getSchedulingCpuBudget(const DcpScheduleOptions& options) {
if (options.processorCount > 0)
return options.processorCount;
return std::numeric_limits<size_t>::max();
@@ -72,7 +72,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
@@ -80,7 +80,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
VirtualGraph buildInitialVirtualGraph(const ComputeGraph& graph) {
VirtualGraph virtualGraph;
virtualGraph.nodes.reserve(graph.nodes.size());
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
@@ -93,14 +93,14 @@ VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
std::vector<IndexedEdge> edges;
edges.reserve(graph.edges.size());
for (const ComputeGraphEdge &edge : graph.edges)
for (const ComputeGraphEdge& edge : graph.edges)
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
virtualGraph.edges = aggregateEdges(edges);
return virtualGraph;
}
TimingInfo computeTiming(const VirtualGraph &graph) {
TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
@@ -122,7 +122,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode &node = graph.nodes[nodeIndex];
const VirtualNode& node = graph.nodes[nodeIndex];
if (!node.originalNodeIndices.empty())
return node.originalNodeIndices.front();
return nodeIndex;
@@ -181,7 +181,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
@@ -191,14 +191,14 @@ std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &gr
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto &neighbours : adjacency) {
for (auto& neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
@@ -240,7 +240,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
@@ -288,7 +288,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
@@ -301,10 +301,10 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
WindowScheduleResult scheduleWindow(const VirtualGraph& graph,
llvm::ArrayRef<size_t> selectedNodes,
const DcpScheduleOptions &options,
mlir::MLIRContext *context) {
const DcpScheduleOptions& options,
mlir::MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
@@ -338,17 +338,17 @@ WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto &task : scheduledTasks)
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph &graph,
bool coarsenGraph(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph &coarsenedGraph,
std::vector<size_t> &oldToNewNode) {
VirtualGraph& coarsenedGraph,
std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
@@ -358,7 +358,7 @@ bool coarsenGraph(const VirtualGraph &graph,
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto &mergeGroup : mergeGroups) {
for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
@@ -395,7 +395,7 @@ bool coarsenGraph(const VirtualGraph &graph,
continue;
}
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
@@ -403,8 +403,9 @@ bool coarsenGraph(const VirtualGraph &graph,
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode &memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
memberNode.originalNodeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
@@ -437,7 +438,7 @@ bool coarsenGraph(const VirtualGraph &graph,
return true;
}
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions& options) {
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
if (nodeCount > static_cast<size_t>(maxCpuCount))
@@ -445,7 +446,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &op
return windowSize;
}
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
void assignFeasibleAest(const ComputeGraph& graph, MergeScheduleResult& result) {
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
nodeIndexByInstance.reserve(graph.nodes.size());
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
@@ -458,7 +459,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
for (const ComputeGraphEdge &edge : graph.edges) {
for (const ComputeGraphEdge& edge : graph.edges) {
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
@@ -473,15 +474,15 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
}
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
for (const ComputeGraphNode &node : graph.nodes) {
for (const ComputeGraphNode& node : graph.nodes) {
size_t cpu = result.computeToCpuMap.lookup(node.instance);
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
for (auto& entry : tasksByCpu) {
auto& scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
@@ -512,7 +513,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
readyNodes.pop();
processedNodeCount++;
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
for (const ScheduledEdge& edge : scheduledChildren[sourceIndex]) {
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
incomingEdgeCount[edge.target]--;
@@ -528,7 +529,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
}
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph& graph, const ComputeGraph& originalGraph) {
MergeScheduleResult result;
TimingInfo timing = computeTiming(graph);
@@ -542,7 +543,7 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalNodeIndices)
originalNodeToCpu[originalIndex] = cpu;
}
@@ -556,17 +557,17 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
result.cpuToLastComputeMap[cpu] = node.instance;
}
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
assignFeasibleAest(originalGraph, result);
return result;
}
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP& graphDCP, const ComputeGraph& graph) {
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(graph.nodes.size());
for (const ComputeGraphNode &node : graph.nodes)
for (const ComputeGraphNode& node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance);
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
return result;
}
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
MergeScheduleResult
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys;
@@ -599,12 +601,12 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
nodeOrderKeys.reserve(graph.nodes.size());
edges.reserve(graph.edges.size());
for (const ComputeGraphNode &node : graph.nodes) {
for (const ComputeGraphNode& node : graph.nodes) {
nodeWeights.push_back(node.weight);
nodeCrossbarUsage.push_back(node.crossbarUsage);
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
}
for (const ComputeGraphEdge &edge : graph.edges) {
for (const ComputeGraphEdge& edge : graph.edges) {
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
}
@@ -617,11 +619,11 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
return buildResultFromScheduledGraph(graphDCP, graph);
}
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
bool needsExactScheduledBatches(const ComputeGraph& graph, const DcpScheduleOptions& options) {
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
return false;
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode& node) {
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
});
@@ -630,7 +632,7 @@ bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOpti
} // namespace
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
if (needsExactScheduledBatches(graph, options))
return runLegacyDcp(graph, options, context);
@@ -15,7 +15,7 @@ struct DcpScheduleOptions {
};
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context);
} // namespace spatial
} // namespace onnx_mlir
@@ -1,13 +1,13 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp"
#include "ComputeGraph.hpp"
#include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp"
@@ -20,15 +20,13 @@ namespace {
MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft:
return MergeSchedulerKind::Peft;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
}
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, CrossbarUsage crossbarCapacity) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
@@ -45,9 +43,9 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
for (auto& entry : tasksByCpu) {
auto& scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
@@ -70,7 +68,7 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
}
for (const ComputeGraphEdge &edge : graph.edges) {
for (const ComputeGraphEdge& edge : graph.edges) {
const ComputeInstance source = graph.nodes[edge.source].instance;
const ComputeInstance target = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
@@ -97,8 +95,8 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
} // namespace
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
: entryOp(op) {
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation* op)
: entryOp(op) {
result = run();
}
@@ -115,20 +113,17 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler(
graph,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()});
schedule = runPeftScheduler(graph,
PeftScheduleOptions {options.processorCount,
static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
else {
schedule = runDcpScheduler(
graph,
DcpScheduleOptions {
options.processorCount,
dcpCriticalWindowSize.getValue(),
options.allowDcpFallbackForAutoCoreCount
},
entryOp->getContext());
schedule = runDcpScheduler(graph,
DcpScheduleOptions {options.processorCount,
dcpCriticalWindowSize.getValue(),
options.allowDcpFallbackForAutoCoreCount},
entryOp->getContext());
}
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
@@ -22,11 +22,11 @@ struct MergeSchedulingOptions {
class MergeSchedulingAnalysis {
public:
explicit MergeSchedulingAnalysis(mlir::Operation *op);
MergeScheduleResult &getResult() { return result; }
explicit MergeSchedulingAnalysis(mlir::Operation* op);
MergeScheduleResult& getResult() { return result; }
private:
mlir::Operation *entryOp = nullptr;
mlir::Operation* entryOp = nullptr;
MergeScheduleResult result;
MergeScheduleResult run();
@@ -11,10 +11,10 @@ namespace spatial {
struct PeftScheduleOptions {
size_t processorCount = 0;
CrossbarUsage crossbarCapacity = 0;
mlir::MLIRContext *context = nullptr;
mlir::MLIRContext* context = nullptr;
};
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
} // namespace spatial
} // namespace onnx_mlir
+4 -3
View File
@@ -8,8 +8,8 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
(void) withScalarCoreFromBatchLane(
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); });
(void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
return verifyCoreOperands(scalarCore, diagnostics);
});
continue;
}