automatic code reformat
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||
: maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -15,13 +15,13 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||
"pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
llvm::cl::opt<PimMergeSchedulerType>
|
||||
pimMergeScheduler("pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
|
||||
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = detail::invokeWithValues(
|
||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
|
||||
@@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
||||
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
||||
gemmLoc,
|
||||
currOutHSliceType,
|
||||
computeOp.getWeightArgument(aHSliceId),
|
||||
computeOp.getInputArgument(aHSliceId)));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
@@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
||||
unitStrides);
|
||||
tensor::ParallelInsertSliceOp::create(
|
||||
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
|
||||
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
|
||||
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC =
|
||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH =
|
||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW =
|
||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return rewriter.notifyMatchFailure(
|
||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
return rewriter.notifyMatchFailure(resizeOp,
|
||||
"resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
|
||||
@@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
mapper.map(compute.getOutputArgument(resultIndex),
|
||||
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
@@ -141,8 +141,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||
}
|
||||
|
||||
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
|
||||
: scaledOffset;
|
||||
totalOffset =
|
||||
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
|
||||
}
|
||||
|
||||
if (!totalOffset)
|
||||
|
||||
@@ -30,8 +30,8 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
|
||||
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
|
||||
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
|
||||
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
|
||||
@@ -233,10 +233,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
|
||||
@@ -217,8 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -546,11 +546,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult
|
||||
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
OpResult result,
|
||||
Value yieldValue,
|
||||
IRRewriter& rewriter) {
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
|
||||
@@ -318,7 +318,7 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f
|
||||
}
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
|
||||
@@ -39,13 +39,10 @@ private:
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<O
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
@@ -82,10 +83,8 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
case ListDelimiter::Paren: return parser.parseRParen();
|
||||
case ListDelimiter::Square: return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@@ -66,8 +66,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
||||
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
@@ -153,10 +153,9 @@ LogicalResult PimCoreOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != getWeights().size())
|
||||
return emitError("core body must have one block argument per weight");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||
}
|
||||
|
||||
@@ -169,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() {
|
||||
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("core_batch first block argument must have index type");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||
}
|
||||
|
||||
|
||||
@@ -50,10 +50,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
|
||||
if (initArg == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -65,9 +65,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
|
||||
return getRegion().front().getOperations();
|
||||
}
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||
|
||||
void SpatialDialect::initialize() {
|
||||
addTypes<
|
||||
|
||||
@@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
return failure();
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren:
|
||||
return parser.parseRParen();
|
||||
case ListDelimiter::Square:
|
||||
return parser.parseRSquare();
|
||||
case ListDelimiter::Paren: return parser.parseRParen();
|
||||
case ListDelimiter::Square: return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -272,9 +272,9 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
|
||||
<< " is used by " << op->getName();
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
@@ -457,14 +457,12 @@ LogicalResult SpatCompute::verify() {
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
@@ -582,10 +580,9 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "../Scheduling/ComputeGraph.hpp"
|
||||
#include "../Scheduling/DcpScheduler.hpp"
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -11,15 +11,15 @@ using DCPAnalysisResult = MergeScheduleResult;
|
||||
struct DCPAnalysis {
|
||||
private:
|
||||
DCPAnalysisResult result;
|
||||
mlir::Operation *entryOp;
|
||||
mlir::Operation* entryOp;
|
||||
DCPAnalysisResult run();
|
||||
|
||||
public:
|
||||
DCPAnalysis(mlir::Operation *op)
|
||||
DCPAnalysis(mlir::Operation* op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
DCPAnalysisResult &getResult() { return result; }
|
||||
DCPAnalysisResult& getResult() { return result; }
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
+119
-125
@@ -1,5 +1,3 @@
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
@@ -24,6 +22,7 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -72,9 +71,7 @@ struct CpuSlotKey {
|
||||
};
|
||||
|
||||
struct CpuSlotKeyInfo {
|
||||
static CpuSlotKey getEmptyKey() {
|
||||
return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()};
|
||||
}
|
||||
static CpuSlotKey getEmptyKey() { return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()}; }
|
||||
|
||||
static CpuSlotKey getTombstoneKey() {
|
||||
return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()};
|
||||
@@ -139,11 +136,12 @@ struct MaterializerState {
|
||||
DenseMap<Value, Value> hostReplacements;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
|
||||
MaterializerState(func::FuncOp func,
|
||||
const MergeScheduleResult& schedule,
|
||||
int64_t& nextChannelId)
|
||||
: func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()),
|
||||
nextChannelId(nextChannelId) {}
|
||||
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||
: func(func),
|
||||
schedule(schedule),
|
||||
rewriter(func.getContext()),
|
||||
constantFolder(func.getContext()),
|
||||
nextChannelId(nextChannelId) {}
|
||||
};
|
||||
|
||||
bool isConstantLike(Value value) {
|
||||
@@ -189,11 +187,12 @@ std::optional<uint32_t> getConstantFirstSliceOffset(tensor::ExtractSliceOp extra
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch,
|
||||
uint32_t laneStart,
|
||||
uint32_t laneCount,
|
||||
size_t resultIndex) {
|
||||
return {{batch.getOperation(), laneStart, laneCount}, resultIndex};
|
||||
ProducerKey
|
||||
getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) {
|
||||
return {
|
||||
{batch.getOperation(), laneStart, laneCount},
|
||||
resultIndex
|
||||
};
|
||||
}
|
||||
|
||||
ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) {
|
||||
@@ -202,8 +201,8 @@ ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex)
|
||||
|
||||
bool isWholeBatchProducerKey(ProducerKey key) {
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 &&
|
||||
key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
|
||||
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0
|
||||
&& key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
|
||||
}
|
||||
|
||||
SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) {
|
||||
@@ -306,7 +305,10 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return std::nullopt;
|
||||
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
|
||||
return ProducerKey {
|
||||
{compute.getOperation(), 0, 1},
|
||||
result.getResultNumber()
|
||||
};
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
|
||||
@@ -316,10 +318,8 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
|
||||
|
||||
if (batch.getNumResults() != 0) {
|
||||
if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op))
|
||||
return getBatchLaneProducerKey(batch,
|
||||
consumerInstance->laneStart,
|
||||
consumerInstance->laneCount,
|
||||
result.getResultNumber());
|
||||
return getBatchLaneProducerKey(
|
||||
batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber());
|
||||
|
||||
return getWholeBatchProducerKey(batch, result.getResultNumber());
|
||||
}
|
||||
@@ -489,12 +489,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto batch = SpatComputeBatch::create(state.rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
|
||||
ValueRange {},
|
||||
ValueRange {});
|
||||
auto batch =
|
||||
SpatComputeBatch::create(state.rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
|
||||
ValueRange {},
|
||||
ValueRange {});
|
||||
batch.getProperties().setOperandSegmentSizes({0, 0});
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(materializedClass.cpus.size());
|
||||
@@ -506,15 +507,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
|
||||
SmallVector<Location, 4> blockArgLocs {loc};
|
||||
llvm::append_range(blockArgTypes, resultTypes);
|
||||
blockArgLocs.append(resultTypes.size(), loc);
|
||||
Block* body = state.rewriter.createBlock(
|
||||
&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
Block* body =
|
||||
state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
state.rewriter.setInsertionPointToEnd(body);
|
||||
if (resultTypes.empty()) {
|
||||
if (resultTypes.empty())
|
||||
SpatYieldOp::create(state.rewriter, loc, ValueRange {});
|
||||
}
|
||||
else {
|
||||
else
|
||||
SpatInParallelOp::create(state.rewriter, loc);
|
||||
}
|
||||
materializedClass.op = batch.getOperation();
|
||||
materializedClass.body = body;
|
||||
state.rewriter.setInsertionPointAfter(batch.getOperation());
|
||||
@@ -559,7 +558,8 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
||||
else {
|
||||
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
|
||||
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
||||
BlockArgument arg = materializedClass.body->insertArgument(materializedClass.body->getNumArguments()-1, input.getType(), input.getLoc());
|
||||
BlockArgument arg = materializedClass.body->insertArgument(
|
||||
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
|
||||
materializedClass.inputArgs[input] = arg;
|
||||
return arg;
|
||||
}
|
||||
@@ -586,9 +586,8 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
|
||||
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& state,
|
||||
const MaterializedClass& materializedClass,
|
||||
SlotId slot) {
|
||||
FailureOr<SmallVector<ComputeInstance, 8>>
|
||||
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
||||
SmallVector<ComputeInstance, 8> peers;
|
||||
peers.reserve(materializedClass.cpus.size());
|
||||
for (CpuId cpu : materializedClass.cpus) {
|
||||
@@ -601,9 +600,9 @@ FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& s
|
||||
}
|
||||
|
||||
Value createOriginalLaneValue(MaterializerState& state,
|
||||
MaterializedClass& materializedClass,
|
||||
ArrayRef<ComputeInstance> peers,
|
||||
Location loc) {
|
||||
MaterializedClass& materializedClass,
|
||||
ArrayRef<ComputeInstance> peers,
|
||||
Location loc) {
|
||||
assert(!peers.empty() && "expected at least one peer instance");
|
||||
if (!materializedClass.isBatch)
|
||||
return createIndexConstant(state, materializedClass.op, peers.front().laneStart);
|
||||
@@ -751,12 +750,12 @@ SmallVector<ClassId, 4> getSortedDestinationClasses(MaterializerState& state, Pr
|
||||
}
|
||||
|
||||
Value appendReceive(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
MaterializedClass& targetClass,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, targetClass.op, channelIds);
|
||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds);
|
||||
@@ -774,22 +773,18 @@ Value appendReceive(MaterializerState& state,
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
return SpatChannelReceiveOp::create(state.rewriter,
|
||||
loc,
|
||||
type,
|
||||
channelIdValues.front(),
|
||||
sourceCoreIdValues.front(),
|
||||
targetCoreIdValues.front())
|
||||
return SpatChannelReceiveOp::create(
|
||||
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
Value appendHostReceive(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
state.rewriter.setInsertionPointAfter(sourceClass.op);
|
||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
|
||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
|
||||
@@ -802,19 +797,13 @@ Value appendHostReceive(MaterializerState& state,
|
||||
}
|
||||
|
||||
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
|
||||
return SpatChannelReceiveOp::create(state.rewriter,
|
||||
loc,
|
||||
type,
|
||||
channelIdValues.front(),
|
||||
sourceCoreIdValues.front(),
|
||||
targetCoreIdValues.front())
|
||||
return SpatChannelReceiveOp::create(
|
||||
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
LogicalResult setHostOutputValue(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value originalOutput,
|
||||
Value payload) {
|
||||
LogicalResult
|
||||
setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
|
||||
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
|
||||
if (resultIt == sourceClass.hostOutputToResultIndex.end())
|
||||
return sourceClass.op->emitError("missing host result slot for materialized output");
|
||||
@@ -864,12 +853,12 @@ LogicalResult setHostOutputValue(MaterializerState& state,
|
||||
}
|
||||
|
||||
void appendScalarSend(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
int64_t channelId,
|
||||
int32_t sourceCoreId,
|
||||
int32_t targetCoreId,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
int64_t channelId,
|
||||
int32_t sourceCoreId,
|
||||
int32_t targetCoreId,
|
||||
Location loc) {
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId);
|
||||
Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId);
|
||||
@@ -878,12 +867,12 @@ void appendScalarSend(MaterializerState& state,
|
||||
}
|
||||
|
||||
void appendBatchSend(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
Value payload,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
Location loc) {
|
||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||
SmallVector<Value, 8> channelIdValues = createIndexConstants(state, sourceClass.op, channelIds);
|
||||
SmallVector<Value, 8> sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds);
|
||||
@@ -892,11 +881,11 @@ void appendBatchSend(MaterializerState& state,
|
||||
}
|
||||
|
||||
LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
MaterializedClass& targetClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
MaterializedClass& targetClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Location loc) {
|
||||
if (sourceClass.id == targetClass.id) {
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues[key][targetClass.id] = payload;
|
||||
@@ -908,8 +897,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
||||
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
|
||||
Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef<int64_t>(channelId),
|
||||
ArrayRef<int32_t>(sourceCpu), ArrayRef<int32_t>(targetCpu), loc);
|
||||
Value received = appendReceive(state,
|
||||
targetClass,
|
||||
payload.getType(),
|
||||
ArrayRef<int64_t>(channelId),
|
||||
ArrayRef<int32_t>(sourceCpu),
|
||||
ArrayRef<int32_t>(targetCpu),
|
||||
loc);
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues[key][targetClass.id] = received;
|
||||
return success();
|
||||
@@ -937,7 +931,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
loc);
|
||||
}
|
||||
|
||||
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received =
|
||||
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues[key][targetClass.id] = received;
|
||||
return success();
|
||||
@@ -946,8 +941,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
if (sourceClass.isBatch && !targetClass.isBatch) {
|
||||
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
|
||||
if (!packedKey)
|
||||
return sourceClass.op->emitError(
|
||||
"cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order");
|
||||
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
|
||||
"lanes are not contiguous in send order");
|
||||
|
||||
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
|
||||
if (failed(packedType))
|
||||
@@ -975,7 +970,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
|
||||
if (sourceClass.isBatch && targetClass.isBatch) {
|
||||
if (sourceClass.cpus.size() != targetClass.cpus.size())
|
||||
return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes");
|
||||
return sourceClass.op->emitError(
|
||||
"cannot materialize batch communication between equivalence classes of different sizes");
|
||||
|
||||
SmallVector<int64_t, 8> channelIds;
|
||||
SmallVector<int32_t, 8> sourceCoreIds;
|
||||
@@ -991,7 +987,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
}
|
||||
|
||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received =
|
||||
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
for (ProducerKey key : keys)
|
||||
state.availableValues[key][targetClass.id] = received;
|
||||
return success();
|
||||
@@ -1001,11 +998,11 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
||||
}
|
||||
|
||||
LogicalResult emitHostCommunication(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Value originalOutput,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Value originalOutput,
|
||||
Location loc) {
|
||||
if (!hasLiveExternalUse(originalOutput, state.oldComputeOps))
|
||||
return success();
|
||||
|
||||
@@ -1025,23 +1022,25 @@ LogicalResult emitHostCommunication(MaterializerState& state,
|
||||
}
|
||||
|
||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
Value received =
|
||||
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||
state.hostReplacements[originalOutput] = received;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Value originalOutput,
|
||||
Location loc) {
|
||||
MaterializedClass& sourceClass,
|
||||
ArrayRef<ProducerKey> keys,
|
||||
Value payload,
|
||||
Value originalOutput,
|
||||
Location loc) {
|
||||
if (keys.empty())
|
||||
return success();
|
||||
|
||||
if (!sourceClass.isBatch) {
|
||||
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
||||
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||
if (failed(
|
||||
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||
return failure();
|
||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
||||
return failure();
|
||||
@@ -1065,11 +1064,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
||||
return success();
|
||||
}
|
||||
|
||||
FailureOr<Value> materializeWholeBatchInput(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
ProducerKey key,
|
||||
Type resultType,
|
||||
Location loc) {
|
||||
FailureOr<Value> materializeWholeBatchInput(
|
||||
MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) {
|
||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||
auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
|
||||
if (!batch || !resultTensorType || resultTensorType.getRank() == 0)
|
||||
@@ -1115,9 +1111,9 @@ FailureOr<Value> materializeWholeBatchInput(MaterializerState& state,
|
||||
}
|
||||
|
||||
FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
Value input,
|
||||
const ComputeInstance& consumerInstance) {
|
||||
MaterializedClass& targetClass,
|
||||
Value input,
|
||||
const ComputeInstance& consumerInstance) {
|
||||
if (isConstantLike(input))
|
||||
return input;
|
||||
|
||||
@@ -1135,9 +1131,9 @@ FailureOr<Value> resolveInputValue(MaterializerState& state,
|
||||
}
|
||||
|
||||
void mapWeights(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
IRMapping& mapper) {
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
IRMapping& mapper) {
|
||||
Operation* op = instance.op;
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
for (auto [index, weight] : llvm::enumerate(compute.getWeights()))
|
||||
@@ -1151,9 +1147,9 @@ void mapWeights(MaterializerState& state,
|
||||
}
|
||||
|
||||
LogicalResult mapInputs(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
IRMapping& mapper) {
|
||||
MaterializedClass& targetClass,
|
||||
const ComputeInstance& instance,
|
||||
IRMapping& mapper) {
|
||||
Operation* op = instance.op;
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
for (auto [index, input] : llvm::enumerate(compute.getInputs())) {
|
||||
@@ -1200,9 +1196,8 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
|
||||
return outputs;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value, 4>> cloneInstanceBody(MaterializerState& state,
|
||||
MaterializedClass& targetClass,
|
||||
ArrayRef<ComputeInstance> peers) {
|
||||
FailureOr<SmallVector<Value, 4>>
|
||||
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
||||
assert(!peers.empty() && "expected at least one peer instance");
|
||||
const ComputeInstance& instance = peers.front();
|
||||
Operation* sourceOp = instance.op;
|
||||
@@ -1318,9 +1313,8 @@ LogicalResult eraseOldComputeOps(MaterializerState& state) {
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult MergeScheduleMaterializer::run(func::FuncOp func,
|
||||
const MergeScheduleResult& schedule,
|
||||
int64_t& nextChannelId) {
|
||||
LogicalResult
|
||||
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
|
||||
if (schedule.dominanceOrderCompute.empty())
|
||||
return success();
|
||||
|
||||
|
||||
@@ -10,8 +10,7 @@ namespace spatial {
|
||||
|
||||
class MergeScheduleMaterializer {
|
||||
public:
|
||||
mlir::LogicalResult
|
||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||
mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()),
|
||||
phase(phaseName.str()) {
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
||||
|
||||
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
||||
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
||||
<< " compute=" << counts.topLevelComputeCount
|
||||
<< " compute_batch=" << counts.topLevelComputeBatchCount
|
||||
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
||||
<< " tensor_send=" << counts.tensorChannelSendCount
|
||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount
|
||||
<< " wvmm=" << counts.wvmmCount
|
||||
<< " vadd=" << counts.vaddCount
|
||||
<< " scf_for=" << counts.scfForCount << "\n";
|
||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
||||
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
||||
}
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
@@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
||||
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
||||
}
|
||||
|
||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) {
|
||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
|
||||
ValueRange sourceWeights) {
|
||||
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
|
||||
targetWeightIndices[weight].push_back(weightIndex);
|
||||
|
||||
@@ -7,6 +7,6 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
SmallVector<Value> sourceCoreIdValues =
|
||||
createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||
SmallVector<Value> targetCoreIdValues =
|
||||
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||
spatial::SpatChannelSendTensorOp::create(
|
||||
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
||||
for (auto op : run.ops)
|
||||
|
||||
@@ -39,11 +39,11 @@ struct ComputeGraph {
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
};
|
||||
|
||||
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph &graph);
|
||||
ComputeGraph buildComputeGraph(mlir::Operation* entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph& graph);
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -11,11 +11,11 @@ namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeInstance {
|
||||
mlir::Operation *op = nullptr;
|
||||
mlir::Operation* op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance &other) const {
|
||||
bool operator==(const ComputeInstance& other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
@@ -29,16 +29,15 @@ namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
|
||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||
}
|
||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
|
||||
+7
-7
@@ -27,15 +27,15 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex)
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
const ComputeInstance* consumerInstance = nullptr);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
|
||||
const ComputeInstance *consumerInstance = nullptr);
|
||||
const ComputeInstance* consumerInstance = nullptr);
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance& instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance& instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance);
|
||||
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance);
|
||||
mlir::Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "../DCPGraph/Graph.hpp"
|
||||
#include "DcpScheduler.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
@@ -47,7 +47,7 @@ struct WindowScheduleResult {
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||
size_t getSchedulingCpuBudget(const DcpScheduleOptions& options) {
|
||||
if (options.processorCount > 0)
|
||||
return options.processorCount;
|
||||
return std::numeric_limits<size_t>::max();
|
||||
@@ -72,7 +72,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
for (auto [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back(
|
||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||
@@ -80,7 +80,7 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||
VirtualGraph buildInitialVirtualGraph(const ComputeGraph& graph) {
|
||||
VirtualGraph virtualGraph;
|
||||
virtualGraph.nodes.reserve(graph.nodes.size());
|
||||
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
|
||||
@@ -93,14 +93,14 @@ VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||
|
||||
std::vector<IndexedEdge> edges;
|
||||
edges.reserve(graph.edges.size());
|
||||
for (const ComputeGraphEdge &edge : graph.edges)
|
||||
for (const ComputeGraphEdge& edge : graph.edges)
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
virtualGraph.edges = aggregateEdges(edges);
|
||||
return virtualGraph;
|
||||
}
|
||||
|
||||
TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||
TimingInfo computeTiming(const VirtualGraph& graph) {
|
||||
TimingInfo timing;
|
||||
size_t nodeCount = graph.nodes.size();
|
||||
timing.aest.assign(nodeCount, 0);
|
||||
@@ -122,7 +122,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||
}
|
||||
|
||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||
const VirtualNode &node = graph.nodes[nodeIndex];
|
||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||
if (!node.originalNodeIndices.empty())
|
||||
return node.originalNodeIndices.front();
|
||||
return nodeIndex;
|
||||
@@ -181,7 +181,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||
return timing;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
@@ -191,14 +191,14 @@ std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &gr
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto &neighbours : adjacency) {
|
||||
for (auto& neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
@@ -240,7 +240,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
@@ -288,7 +288,7 @@ std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const Timing
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
@@ -301,10 +301,10 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph,
|
||||
llvm::ArrayRef<size_t> selectedNodes,
|
||||
const DcpScheduleOptions &options,
|
||||
mlir::MLIRContext *context) {
|
||||
const DcpScheduleOptions& options,
|
||||
mlir::MLIRContext* context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> windowNodeOrderKeys;
|
||||
@@ -338,17 +338,17 @@ WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||
std::vector<size_t> mergeGroup;
|
||||
mergeGroup.reserve(scheduledTasks.size());
|
||||
for (const auto &task : scheduledTasks)
|
||||
for (const auto& task : scheduledTasks)
|
||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph &graph,
|
||||
bool coarsenGraph(const VirtualGraph& graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph &coarsenedGraph,
|
||||
std::vector<size_t> &oldToNewNode) {
|
||||
VirtualGraph& coarsenedGraph,
|
||||
std::vector<size_t>& oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
@@ -358,7 +358,7 @@ bool coarsenGraph(const VirtualGraph &graph,
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto &mergeGroup : mergeGroups) {
|
||||
for (const auto& mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
@@ -395,7 +395,7 @@ bool coarsenGraph(const VirtualGraph &graph,
|
||||
continue;
|
||||
}
|
||||
|
||||
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
@@ -403,8 +403,9 @@ bool coarsenGraph(const VirtualGraph &graph,
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
|
||||
memberNode.originalNodeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
@@ -437,7 +438,7 @@ bool coarsenGraph(const VirtualGraph &graph,
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions& options) {
|
||||
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
@@ -445,7 +446,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &op
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||
void assignFeasibleAest(const ComputeGraph& graph, MergeScheduleResult& result) {
|
||||
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
@@ -458,7 +459,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
|
||||
|
||||
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||
@@ -473,15 +474,15 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
|
||||
}
|
||||
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
for (const ComputeGraphNode& node : graph.nodes) {
|
||||
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
for (auto& entry : tasksByCpu) {
|
||||
auto& scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
@@ -512,7 +513,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
|
||||
readyNodes.pop();
|
||||
processedNodeCount++;
|
||||
|
||||
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||
for (const ScheduledEdge& edge : scheduledChildren[sourceIndex]) {
|
||||
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||
incomingEdgeCount[edge.target]--;
|
||||
@@ -528,7 +529,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result)
|
||||
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
|
||||
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph& graph, const ComputeGraph& originalGraph) {
|
||||
MergeScheduleResult result;
|
||||
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
@@ -542,7 +543,7 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
|
||||
|
||||
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
|
||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
|
||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
||||
for (size_t originalIndex : virtualNode.originalNodeIndices)
|
||||
originalNodeToCpu[originalIndex] = cpu;
|
||||
}
|
||||
@@ -556,17 +557,17 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const
|
||||
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
|
||||
result.cpuToLastComputeMap[cpu] = node.instance;
|
||||
}
|
||||
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
result.isLastComputeOfCpu.insert(lastCompute);
|
||||
assignFeasibleAest(originalGraph, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
|
||||
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP& graphDCP, const ComputeGraph& graph) {
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||
for (const ComputeGraphNode &node : graph.nodes)
|
||||
for (const ComputeGraphNode& node : graph.nodes)
|
||||
result.dominanceOrderCompute.push_back(node.instance);
|
||||
|
||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
MergeScheduleResult
|
||||
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
|
||||
llvm::SmallVector<Weight> nodeWeights;
|
||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||
@@ -599,12 +601,12 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
|
||||
nodeOrderKeys.reserve(graph.nodes.size());
|
||||
edges.reserve(graph.edges.size());
|
||||
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
for (const ComputeGraphNode& node : graph.nodes) {
|
||||
nodeWeights.push_back(node.weight);
|
||||
nodeCrossbarUsage.push_back(node.crossbarUsage);
|
||||
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
|
||||
}
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
}
|
||||
@@ -617,11 +619,11 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt
|
||||
return buildResultFromScheduledGraph(graphDCP, graph);
|
||||
}
|
||||
|
||||
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||
bool needsExactScheduledBatches(const ComputeGraph& graph, const DcpScheduleOptions& options) {
|
||||
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||
return false;
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode& node) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
@@ -630,7 +632,7 @@ bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOpti
|
||||
} // namespace
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
|
||||
if (needsExactScheduledBatches(graph, options))
|
||||
return runLegacyDcp(graph, options, context);
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ struct DcpScheduleOptions {
|
||||
};
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||
runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
+20
-25
@@ -1,13 +1,13 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "MergeSchedulingAnalysis.hpp"
|
||||
#include "PeftScheduler.hpp"
|
||||
@@ -20,15 +20,13 @@ namespace {
|
||||
|
||||
MergeSchedulerKind getSchedulerKind() {
|
||||
switch (pimMergeScheduler.getValue()) {
|
||||
case MergeSchedulerPeft:
|
||||
return MergeSchedulerKind::Peft;
|
||||
case MergeSchedulerDcp:
|
||||
return MergeSchedulerKind::Dcp;
|
||||
case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
|
||||
case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
|
||||
}
|
||||
llvm_unreachable("unknown merge scheduler kind");
|
||||
}
|
||||
|
||||
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, CrossbarUsage crossbarCapacity) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
@@ -45,9 +43,9 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
|
||||
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
for (auto& entry : tasksByCpu) {
|
||||
auto& scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
@@ -70,7 +68,7 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
|
||||
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||
}
|
||||
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
for (const ComputeGraphEdge& edge : graph.edges) {
|
||||
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||
@@ -97,8 +95,8 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
|
||||
: entryOp(op) {
|
||||
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation* op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
|
||||
@@ -115,20 +113,17 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||
|
||||
MergeScheduleResult schedule;
|
||||
if (options.kind == MergeSchedulerKind::Peft) {
|
||||
schedule = runPeftScheduler(
|
||||
graph,
|
||||
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
schedule = runPeftScheduler(graph,
|
||||
PeftScheduleOptions {options.processorCount,
|
||||
static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
}
|
||||
else {
|
||||
schedule = runDcpScheduler(
|
||||
graph,
|
||||
DcpScheduleOptions {
|
||||
options.processorCount,
|
||||
dcpCriticalWindowSize.getValue(),
|
||||
options.allowDcpFallbackForAutoCoreCount
|
||||
},
|
||||
entryOp->getContext());
|
||||
schedule = runDcpScheduler(graph,
|
||||
DcpScheduleOptions {options.processorCount,
|
||||
dcpCriticalWindowSize.getValue(),
|
||||
options.allowDcpFallbackForAutoCoreCount},
|
||||
entryOp->getContext());
|
||||
}
|
||||
|
||||
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||
|
||||
+3
-3
@@ -22,11 +22,11 @@ struct MergeSchedulingOptions {
|
||||
|
||||
class MergeSchedulingAnalysis {
|
||||
public:
|
||||
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||
MergeScheduleResult &getResult() { return result; }
|
||||
explicit MergeSchedulingAnalysis(mlir::Operation* op);
|
||||
MergeScheduleResult& getResult() { return result; }
|
||||
|
||||
private:
|
||||
mlir::Operation *entryOp = nullptr;
|
||||
mlir::Operation* entryOp = nullptr;
|
||||
MergeScheduleResult result;
|
||||
|
||||
MergeScheduleResult run();
|
||||
|
||||
@@ -11,10 +11,10 @@ namespace spatial {
|
||||
struct PeftScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
CrossbarUsage crossbarCapacity = 0;
|
||||
mlir::MLIRContext *context = nullptr;
|
||||
mlir::MLIRContext* context = nullptr;
|
||||
};
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -8,8 +8,8 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
(void) withScalarCoreFromBatchLane(
|
||||
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); });
|
||||
(void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
|
||||
return verifyCoreOperands(scalarCore, diagnostics);
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user