automatic code reformat
This commit is contained in:
@@ -13,7 +13,8 @@
|
|||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
struct CappedDiagnosticReporter {
|
struct CappedDiagnosticReporter {
|
||||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||||
|
: maxReportedFailures(maxReportedFailures) {}
|
||||||
|
|
||||||
template <typename EmitFn>
|
template <typename EmitFn>
|
||||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||||
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
|
|||||||
|
|
||||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||||
if (numFailures > maxReportedFailures)
|
if (numFailures > maxReportedFailures)
|
||||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||||
<< failureDescription;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasFailure() const { return numFailures != 0; }
|
bool hasFailure() const { return numFailures != 0; }
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
|
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -15,8 +15,8 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
|||||||
llvm::cl::init(EmitPimCodegen),
|
llvm::cl::init(EmitPimCodegen),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
llvm::cl::opt<PimMergeSchedulerType>
|
||||||
"pim-merge-scheduler",
|
pimMergeScheduler("pim-merge-scheduler",
|
||||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||||
|
|||||||
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
|
|||||||
|
|
||||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||||
detail::invokeWithValues(
|
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
detail::getInputBlockArgs(block, weights.size()),
|
||||||
|
std::make_index_sequence<NumInputs> {});
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
return computeOp;
|
return computeOp;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
auto bodyResult = detail::invokeWithValues(
|
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||||
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
|
detail::getInputBlockArgs(block, weights.size()),
|
||||||
|
std::make_index_sequence<NumInputs> {});
|
||||||
if (mlir::failed(bodyResult)) {
|
if (mlir::failed(bodyResult)) {
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
rewriter.eraseOp(computeOp);
|
rewriter.eraseOp(computeOp);
|
||||||
|
|||||||
@@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
SmallVector<Value> vmmOutputs;
|
SmallVector<Value> vmmOutputs;
|
||||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(
|
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
||||||
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
|
gemmLoc,
|
||||||
|
currOutHSliceType,
|
||||||
|
computeOp.getWeightArgument(aHSliceId),
|
||||||
|
computeOp.getInputArgument(aHSliceId)));
|
||||||
if (vmmOutputs.empty()) {
|
if (vmmOutputs.empty()) {
|
||||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||||
return failure();
|
return failure();
|
||||||
@@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||||
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
|
tensor::ParallelInsertSliceOp::create(
|
||||||
unitStrides);
|
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||||
rewriter.setInsertionPointAfter(batchOp);
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||||
|
|||||||
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
|
|||||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value collapseBatchDims(Value value,
|
static Value
|
||||||
int64_t batchSize,
|
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||||
int64_t rows,
|
|
||||||
int64_t cols,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
if (type.getRank() == 2 || type.getRank() == 3)
|
if (type.getRank() == 2 || type.getRank() == 3)
|
||||||
return value;
|
return value;
|
||||||
|
|
||||||
auto collapsedType =
|
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||||
SmallVector<ReassociationIndices> reassociation = {
|
|
||||||
ReassociationIndices {},
|
|
||||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
|
||||||
};
|
|
||||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||||
reassociation.front().push_back(dim);
|
reassociation.front().push_back(dim);
|
||||||
|
|
||||||
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
|
|||||||
return collapseCompute.getResult(0);
|
return collapseCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value expandBatchDims(Value value,
|
static Value
|
||||||
RankedTensorType outputType,
|
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
|
||||||
size_t batchRank,
|
|
||||||
PatternRewriter& rewriter,
|
|
||||||
Location loc) {
|
|
||||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||||
return value;
|
return value;
|
||||||
|
|
||||||
SmallVector<ReassociationIndices> reassociation = {
|
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||||
ReassociationIndices {},
|
|
||||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
|
||||||
};
|
|
||||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
|
||||||
|
|||||||
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
|
|||||||
|
|
||||||
Value outputC = channelLoop.getInductionVar();
|
Value outputC = channelLoop.getInductionVar();
|
||||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||||
Value inputC =
|
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
|
||||||
|
|
||||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||||
|
|
||||||
Value outputH = heightLoop.getInductionVar();
|
Value outputH = heightLoop.getInductionVar();
|
||||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||||
Value inputH =
|
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
|
||||||
|
|
||||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||||
|
|
||||||
Value outputW = widthLoop.getInductionVar();
|
Value outputW = widthLoop.getInductionVar();
|
||||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||||
Value inputW =
|
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
|
||||||
|
|
||||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||||
Value inputSlice =
|
Value inputSlice =
|
||||||
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
|||||||
|
|
||||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||||
|| resizeOp.getNearestMode() != "floor")
|
|| resizeOp.getNearestMode() != "floor")
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(resizeOp,
|
||||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
"resize lowering currently supports only nearest + asymmetric + floor.");
|
||||||
|
|
||||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||||
|
|||||||
@@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
|||||||
}
|
}
|
||||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||||
auto* newBlock =
|
auto* newBlock = rewriter.createBlock(
|
||||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||||
rewriter.setInsertionPointToStart(newBlock);
|
rewriter.setInsertionPointToStart(newBlock);
|
||||||
@@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
mapper.map(oldArg, *clonedValue);
|
mapper.map(oldArg, *clonedValue);
|
||||||
}
|
}
|
||||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||||
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
mapper.map(compute.getOutputArgument(resultIndex),
|
||||||
|
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||||
|
|
||||||
for (Operation& op : oldBlock)
|
for (Operation& op : oldBlock)
|
||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
@@ -141,8 +141,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult()
|
totalOffset =
|
||||||
: scaledOffset;
|
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!totalOffset)
|
if (!totalOffset)
|
||||||
|
|||||||
@@ -233,10 +233,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
|||||||
if (!computeOp.getWeights().empty())
|
if (!computeOp.getWeights().empty())
|
||||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
auto coreOp = PimCoreOp::create(rewriter,
|
auto coreOp = PimCoreOp::create(
|
||||||
loc,
|
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||||
ValueRange(computeWeights),
|
|
||||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
|
||||||
rewriter.setInsertionPointToStart(&block);
|
rewriter.setInsertionPointToStart(&block);
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||||
|
|||||||
@@ -217,8 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(
|
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||||
patterns.getContext());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -546,11 +546,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
|||||||
return ReturnPathLoweringResult::NotReturnPath;
|
return ReturnPathLoweringResult::NotReturnPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
raptor::SpatialToPimPass::ReturnPathLoweringResult
|
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||||
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||||
OpResult result,
|
|
||||||
Value yieldValue,
|
|
||||||
IRRewriter& rewriter) {
|
|
||||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -39,13 +39,10 @@ private:
|
|||||||
size_t coreId = 0;
|
size_t coreId = 0;
|
||||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||||
|
|
||||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp,
|
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
mlir::IRRewriter& rewriter);
|
mlir::LogicalResult
|
||||||
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp,
|
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
||||||
mlir::IRRewriter& rewriter,
|
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
||||||
mlir::OperationFolder& constantFolder);
|
|
||||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
|
||||||
mlir::IRRewriter& rewriter);
|
|
||||||
|
|
||||||
enum class ReturnPathLoweringResult {
|
enum class ReturnPathLoweringResult {
|
||||||
Handled,
|
Handled,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<O
|
|||||||
return parser.parseRParen();
|
return parser.parseRParen();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
static void
|
||||||
|
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||||
printCompressedValueList(printer, arguments, delimiter);
|
printCompressedValueList(printer, arguments, delimiter);
|
||||||
printer << " = ";
|
printer << " = ";
|
||||||
printCompressedValueList(printer, operands, delimiter);
|
printCompressedValueList(printer, operands, delimiter);
|
||||||
@@ -82,10 +83,8 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
|||||||
|
|
||||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||||
switch (currentDelimiter) {
|
switch (currentDelimiter) {
|
||||||
case ListDelimiter::Paren:
|
case ListDelimiter::Paren: return parser.parseRParen();
|
||||||
return parser.parseRParen();
|
case ListDelimiter::Square: return parser.parseRSquare();
|
||||||
case ListDelimiter::Square:
|
|
||||||
return parser.parseRSquare();
|
|
||||||
}
|
}
|
||||||
llvm_unreachable("unsupported delimiter");
|
llvm_unreachable("unsupported delimiter");
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
|
||||||
#include "mlir/IR/Block.h"
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
@@ -66,8 +66,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
|||||||
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
InFlightDiagnostic diagnostic =
|
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||||
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
|
<< kind << " body may only directly reference external constants";
|
||||||
diagnostic.attachNote(op->getLoc())
|
diagnostic.attachNote(op->getLoc())
|
||||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
@@ -153,10 +153,9 @@ LogicalResult PimCoreOp::verify() {
|
|||||||
Block& block = getBody().front();
|
Block& block = getBody().front();
|
||||||
if (block.getNumArguments() != getWeights().size())
|
if (block.getNumArguments() != getWeights().size())
|
||||||
return emitError("core body must have one block argument per weight");
|
return emitError("core body must have one block argument per weight");
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
return emitError("core weight block argument types must match weight operand types exactly");
|
return emitError("core weight block argument types must match weight operand types exactly");
|
||||||
}
|
|
||||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() {
|
|||||||
return emitError("core_batch body must have lane, weight, and input block arguments");
|
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||||
if (!getLaneArgument().getType().isIndex())
|
if (!getLaneArgument().getType().isIndex())
|
||||||
return emitError("core_batch first block argument must have index type");
|
return emitError("core_batch first block argument must have index type");
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||||
}
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
|
||||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||||
return emitError("core_batch input block argument types must match input operand types exactly");
|
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||||
}
|
|
||||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -50,11 +50,10 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
|
|||||||
pendingValues.push_back(result);
|
pendingValues.push_back(result);
|
||||||
|
|
||||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
|
||||||
if (initArg == value)
|
if (initArg == value)
|
||||||
pendingValues.push_back(forOp.getResult(index));
|
pendingValues.push_back(forOp.getResult(index));
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||||
for (OpResult result : user->getResults()) {
|
for (OpResult result : user->getResults()) {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -65,9 +65,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
|||||||
|
|
||||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||||
|
|
||||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
|
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||||
return getRegion().front().getOperations();
|
|
||||||
}
|
|
||||||
|
|
||||||
void SpatialDialect::initialize() {
|
void SpatialDialect::initialize() {
|
||||||
addTypes<
|
addTypes<
|
||||||
|
|||||||
@@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
|
|||||||
return failure();
|
return failure();
|
||||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||||
switch (currentDelimiter) {
|
switch (currentDelimiter) {
|
||||||
case ListDelimiter::Paren:
|
case ListDelimiter::Paren: return parser.parseRParen();
|
||||||
return parser.parseRParen();
|
case ListDelimiter::Square: return parser.parseRSquare();
|
||||||
case ListDelimiter::Square:
|
|
||||||
return parser.parseRSquare();
|
|
||||||
}
|
}
|
||||||
llvm_unreachable("unsupported delimiter");
|
llvm_unreachable("unsupported delimiter");
|
||||||
};
|
};
|
||||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
|
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||||
return failure();
|
return failure();
|
||||||
}
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -273,8 +273,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|
|||||||
|
|
||||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||||
<< kind << " body may only directly reference external constants";
|
<< kind << " body may only directly reference external constants";
|
||||||
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
|
diagnostic.attachNote(op->getLoc())
|
||||||
<< " is used by " << op->getName();
|
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -457,14 +457,12 @@ LogicalResult SpatCompute::verify() {
|
|||||||
if (block.getNumArguments() != expectedArgCount)
|
if (block.getNumArguments() != expectedArgCount)
|
||||||
return emitError("compute body must have weight and input block arguments");
|
return emitError("compute body must have weight and input block arguments");
|
||||||
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||||
}
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
|
||||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||||
return emitError("compute input block argument types must match input operand types exactly");
|
return emitError("compute input block argument types must match input operand types exactly");
|
||||||
}
|
|
||||||
|
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||||
@@ -582,10 +580,9 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
if (!getLaneArgument().getType().isIndex())
|
if (!getLaneArgument().getType().isIndex())
|
||||||
return emitError("compute_batch first block argument must have index type");
|
return emitError("compute_batch first block argument must have index type");
|
||||||
|
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||||
}
|
|
||||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||||
if (blockArg.getType() != input.getType())
|
if (blockArg.getType() != input.getType())
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "DCPAnalysis.hpp"
|
|
||||||
#include "../Scheduling/ComputeGraph.hpp"
|
#include "../Scheduling/ComputeGraph.hpp"
|
||||||
#include "../Scheduling/DcpScheduler.hpp"
|
#include "../Scheduling/DcpScheduler.hpp"
|
||||||
|
#include "DCPAnalysis.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
#include "MaterializeMergeSchedule.hpp"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
@@ -24,6 +22,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
@@ -72,9 +71,7 @@ struct CpuSlotKey {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CpuSlotKeyInfo {
|
struct CpuSlotKeyInfo {
|
||||||
static CpuSlotKey getEmptyKey() {
|
static CpuSlotKey getEmptyKey() { return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()}; }
|
||||||
return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()};
|
|
||||||
}
|
|
||||||
|
|
||||||
static CpuSlotKey getTombstoneKey() {
|
static CpuSlotKey getTombstoneKey() {
|
||||||
return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()};
|
return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()};
|
||||||
@@ -139,10 +136,11 @@ struct MaterializerState {
|
|||||||
DenseMap<Value, Value> hostReplacements;
|
DenseMap<Value, Value> hostReplacements;
|
||||||
DenseSet<Operation*> oldComputeOps;
|
DenseSet<Operation*> oldComputeOps;
|
||||||
|
|
||||||
MaterializerState(func::FuncOp func,
|
MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
|
||||||
const MergeScheduleResult& schedule,
|
: func(func),
|
||||||
int64_t& nextChannelId)
|
schedule(schedule),
|
||||||
: func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()),
|
rewriter(func.getContext()),
|
||||||
|
constantFolder(func.getContext()),
|
||||||
nextChannelId(nextChannelId) {}
|
nextChannelId(nextChannelId) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -189,11 +187,12 @@ std::optional<uint32_t> getConstantFirstSliceOffset(tensor::ExtractSliceOp extra
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch,
|
ProducerKey
|
||||||
uint32_t laneStart,
|
getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) {
|
||||||
uint32_t laneCount,
|
return {
|
||||||
size_t resultIndex) {
|
{batch.getOperation(), laneStart, laneCount},
|
||||||
return {{batch.getOperation(), laneStart, laneCount}, resultIndex};
|
resultIndex
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) {
|
ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) {
|
||||||
@@ -202,8 +201,8 @@ ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex)
|
|||||||
|
|
||||||
bool isWholeBatchProducerKey(ProducerKey key) {
|
bool isWholeBatchProducerKey(ProducerKey key) {
|
||||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||||
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 &&
|
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0
|
||||||
key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
|
&& key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) {
|
SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) {
|
||||||
@@ -306,7 +305,10 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
|
|||||||
auto result = dyn_cast<OpResult>(value);
|
auto result = dyn_cast<OpResult>(value);
|
||||||
if (!result)
|
if (!result)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()};
|
return ProducerKey {
|
||||||
|
{compute.getOperation(), 0, 1},
|
||||||
|
result.getResultNumber()
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
|
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
|
||||||
@@ -316,10 +318,8 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
|
|||||||
|
|
||||||
if (batch.getNumResults() != 0) {
|
if (batch.getNumResults() != 0) {
|
||||||
if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op))
|
if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op))
|
||||||
return getBatchLaneProducerKey(batch,
|
return getBatchLaneProducerKey(
|
||||||
consumerInstance->laneStart,
|
batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber());
|
||||||
consumerInstance->laneCount,
|
|
||||||
result.getResultNumber());
|
|
||||||
|
|
||||||
return getWholeBatchProducerKey(batch, result.getResultNumber());
|
return getWholeBatchProducerKey(batch, result.getResultNumber());
|
||||||
}
|
}
|
||||||
@@ -489,7 +489,8 @@ void createEmptyMaterializedOps(MaterializerState& state) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto batch = SpatComputeBatch::create(state.rewriter,
|
auto batch =
|
||||||
|
SpatComputeBatch::create(state.rewriter,
|
||||||
loc,
|
loc,
|
||||||
TypeRange(resultTypes),
|
TypeRange(resultTypes),
|
||||||
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
|
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
|
||||||
@@ -506,15 +507,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
|
|||||||
SmallVector<Location, 4> blockArgLocs {loc};
|
SmallVector<Location, 4> blockArgLocs {loc};
|
||||||
llvm::append_range(blockArgTypes, resultTypes);
|
llvm::append_range(blockArgTypes, resultTypes);
|
||||||
blockArgLocs.append(resultTypes.size(), loc);
|
blockArgLocs.append(resultTypes.size(), loc);
|
||||||
Block* body = state.rewriter.createBlock(
|
Block* body =
|
||||||
&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
state.rewriter.setInsertionPointToEnd(body);
|
state.rewriter.setInsertionPointToEnd(body);
|
||||||
if (resultTypes.empty()) {
|
if (resultTypes.empty())
|
||||||
SpatYieldOp::create(state.rewriter, loc, ValueRange {});
|
SpatYieldOp::create(state.rewriter, loc, ValueRange {});
|
||||||
}
|
else
|
||||||
else {
|
|
||||||
SpatInParallelOp::create(state.rewriter, loc);
|
SpatInParallelOp::create(state.rewriter, loc);
|
||||||
}
|
|
||||||
materializedClass.op = batch.getOperation();
|
materializedClass.op = batch.getOperation();
|
||||||
materializedClass.body = body;
|
materializedClass.body = body;
|
||||||
state.rewriter.setInsertionPointAfter(batch.getOperation());
|
state.rewriter.setInsertionPointAfter(batch.getOperation());
|
||||||
@@ -559,7 +558,8 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
|||||||
else {
|
else {
|
||||||
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
|
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
|
||||||
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
||||||
BlockArgument arg = materializedClass.body->insertArgument(materializedClass.body->getNumArguments()-1, input.getType(), input.getLoc());
|
BlockArgument arg = materializedClass.body->insertArgument(
|
||||||
|
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
|
||||||
materializedClass.inputArgs[input] = arg;
|
materializedClass.inputArgs[input] = arg;
|
||||||
return arg;
|
return arg;
|
||||||
}
|
}
|
||||||
@@ -586,9 +586,8 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
|
|||||||
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& state,
|
FailureOr<SmallVector<ComputeInstance, 8>>
|
||||||
const MaterializedClass& materializedClass,
|
getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
|
||||||
SlotId slot) {
|
|
||||||
SmallVector<ComputeInstance, 8> peers;
|
SmallVector<ComputeInstance, 8> peers;
|
||||||
peers.reserve(materializedClass.cpus.size());
|
peers.reserve(materializedClass.cpus.size());
|
||||||
for (CpuId cpu : materializedClass.cpus) {
|
for (CpuId cpu : materializedClass.cpus) {
|
||||||
@@ -774,12 +773,8 @@ Value appendReceive(MaterializerState& state,
|
|||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
return SpatChannelReceiveOp::create(state.rewriter,
|
return SpatChannelReceiveOp::create(
|
||||||
loc,
|
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
||||||
type,
|
|
||||||
channelIdValues.front(),
|
|
||||||
sourceCoreIdValues.front(),
|
|
||||||
targetCoreIdValues.front())
|
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -802,19 +797,13 @@ Value appendHostReceive(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
|
assert(channelIds.size() == 1 && "scalar host receive expects one channel");
|
||||||
return SpatChannelReceiveOp::create(state.rewriter,
|
return SpatChannelReceiveOp::create(
|
||||||
loc,
|
state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
|
||||||
type,
|
|
||||||
channelIdValues.front(),
|
|
||||||
sourceCoreIdValues.front(),
|
|
||||||
targetCoreIdValues.front())
|
|
||||||
.getOutput();
|
.getOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult setHostOutputValue(MaterializerState& state,
|
LogicalResult
|
||||||
MaterializedClass& sourceClass,
|
setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
|
||||||
Value originalOutput,
|
|
||||||
Value payload) {
|
|
||||||
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
|
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
|
||||||
if (resultIt == sourceClass.hostOutputToResultIndex.end())
|
if (resultIt == sourceClass.hostOutputToResultIndex.end())
|
||||||
return sourceClass.op->emitError("missing host result slot for materialized output");
|
return sourceClass.op->emitError("missing host result slot for materialized output");
|
||||||
@@ -908,8 +897,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|||||||
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
|
||||||
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
|
||||||
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
|
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
|
||||||
Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef<int64_t>(channelId),
|
Value received = appendReceive(state,
|
||||||
ArrayRef<int32_t>(sourceCpu), ArrayRef<int32_t>(targetCpu), loc);
|
targetClass,
|
||||||
|
payload.getType(),
|
||||||
|
ArrayRef<int64_t>(channelId),
|
||||||
|
ArrayRef<int32_t>(sourceCpu),
|
||||||
|
ArrayRef<int32_t>(targetCpu),
|
||||||
|
loc);
|
||||||
for (ProducerKey key : keys)
|
for (ProducerKey key : keys)
|
||||||
state.availableValues[key][targetClass.id] = received;
|
state.availableValues[key][targetClass.id] = received;
|
||||||
return success();
|
return success();
|
||||||
@@ -937,7 +931,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|||||||
loc);
|
loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
Value received =
|
||||||
|
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
for (ProducerKey key : keys)
|
for (ProducerKey key : keys)
|
||||||
state.availableValues[key][targetClass.id] = received;
|
state.availableValues[key][targetClass.id] = received;
|
||||||
return success();
|
return success();
|
||||||
@@ -946,8 +941,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|||||||
if (sourceClass.isBatch && !targetClass.isBatch) {
|
if (sourceClass.isBatch && !targetClass.isBatch) {
|
||||||
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
|
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
|
||||||
if (!packedKey)
|
if (!packedKey)
|
||||||
return sourceClass.op->emitError(
|
return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
|
||||||
"cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order");
|
"lanes are not contiguous in send order");
|
||||||
|
|
||||||
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
|
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
|
||||||
if (failed(packedType))
|
if (failed(packedType))
|
||||||
@@ -975,7 +970,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|||||||
|
|
||||||
if (sourceClass.isBatch && targetClass.isBatch) {
|
if (sourceClass.isBatch && targetClass.isBatch) {
|
||||||
if (sourceClass.cpus.size() != targetClass.cpus.size())
|
if (sourceClass.cpus.size() != targetClass.cpus.size())
|
||||||
return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes");
|
return sourceClass.op->emitError(
|
||||||
|
"cannot materialize batch communication between equivalence classes of different sizes");
|
||||||
|
|
||||||
SmallVector<int64_t, 8> channelIds;
|
SmallVector<int64_t, 8> channelIds;
|
||||||
SmallVector<int32_t, 8> sourceCoreIds;
|
SmallVector<int32_t, 8> sourceCoreIds;
|
||||||
@@ -991,7 +987,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
Value received =
|
||||||
|
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
for (ProducerKey key : keys)
|
for (ProducerKey key : keys)
|
||||||
state.availableValues[key][targetClass.id] = received;
|
state.availableValues[key][targetClass.id] = received;
|
||||||
return success();
|
return success();
|
||||||
@@ -1025,7 +1022,8 @@ LogicalResult emitHostCommunication(MaterializerState& state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
Value received =
|
||||||
|
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
|
||||||
state.hostReplacements[originalOutput] = received;
|
state.hostReplacements[originalOutput] = received;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@@ -1041,7 +1039,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
|
|
||||||
if (!sourceClass.isBatch) {
|
if (!sourceClass.isBatch) {
|
||||||
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
|
||||||
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
if (failed(
|
||||||
|
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
|
||||||
return failure();
|
return failure();
|
||||||
@@ -1065,11 +1064,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<Value> materializeWholeBatchInput(MaterializerState& state,
|
FailureOr<Value> materializeWholeBatchInput(
|
||||||
MaterializedClass& targetClass,
|
MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) {
|
||||||
ProducerKey key,
|
|
||||||
Type resultType,
|
|
||||||
Location loc) {
|
|
||||||
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
|
||||||
auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
|
auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
|
||||||
if (!batch || !resultTensorType || resultTensorType.getRank() == 0)
|
if (!batch || !resultTensorType || resultTensorType.getRank() == 0)
|
||||||
@@ -1200,9 +1196,8 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
|
|||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
FailureOr<SmallVector<Value, 4>> cloneInstanceBody(MaterializerState& state,
|
FailureOr<SmallVector<Value, 4>>
|
||||||
MaterializedClass& targetClass,
|
cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
|
||||||
ArrayRef<ComputeInstance> peers) {
|
|
||||||
assert(!peers.empty() && "expected at least one peer instance");
|
assert(!peers.empty() && "expected at least one peer instance");
|
||||||
const ComputeInstance& instance = peers.front();
|
const ComputeInstance& instance = peers.front();
|
||||||
Operation* sourceOp = instance.op;
|
Operation* sourceOp = instance.op;
|
||||||
@@ -1318,9 +1313,8 @@ LogicalResult eraseOldComputeOps(MaterializerState& state) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult MergeScheduleMaterializer::run(func::FuncOp func,
|
LogicalResult
|
||||||
const MergeScheduleResult& schedule,
|
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
|
||||||
int64_t& nextChannelId) {
|
|
||||||
if (schedule.dominanceOrderCompute.empty())
|
if (schedule.dominanceOrderCompute.empty())
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,7 @@ namespace spatial {
|
|||||||
|
|
||||||
class MergeScheduleMaterializer {
|
class MergeScheduleMaterializer {
|
||||||
public:
|
public:
|
||||||
mlir::LogicalResult
|
mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
|
||||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
|
|||||||
class ScopedMergePhaseTimer {
|
class ScopedMergePhaseTimer {
|
||||||
public:
|
public:
|
||||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||||
: enabled(isMergeProfilingEnabled()),
|
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||||
phase(phaseName.str()) {
|
|
||||||
if (enabled)
|
if (enabled)
|
||||||
start = std::chrono::steady_clock::now();
|
start = std::chrono::steady_clock::now();
|
||||||
}
|
}
|
||||||
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
|
|||||||
|
|
||||||
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
MergeIrCounts counts = collectMergeIrCounts(funcOp);
|
||||||
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
llvm::errs() << "[merge-profile] " << phaseName << " counts:"
|
||||||
<< " compute=" << counts.topLevelComputeCount
|
<< " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
|
||||||
<< " compute_batch=" << counts.topLevelComputeBatchCount
|
|
||||||
<< " scalar_send=" << counts.scalarChannelSendCount
|
<< " scalar_send=" << counts.scalarChannelSendCount
|
||||||
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
<< " scalar_recv=" << counts.scalarChannelReceiveCount
|
||||||
<< " tensor_send=" << counts.tensorChannelSendCount
|
<< " tensor_send=" << counts.tensorChannelSendCount
|
||||||
<< " tensor_recv=" << counts.tensorChannelReceiveCount
|
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
|
||||||
<< " wvmm=" << counts.wvmmCount
|
<< " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
|
||||||
<< " vadd=" << counts.vaddCount
|
|
||||||
<< " scf_for=" << counts.scfForCount << "\n";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
@@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
|
|||||||
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) {
|
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
|
||||||
|
ValueRange sourceWeights) {
|
||||||
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
|
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
|
||||||
targetWeightIndices[weight].push_back(weightIndex);
|
targetWeightIndices[weight].push_back(weightIndex);
|
||||||
|
|||||||
@@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
|
||||||
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
SmallVector<Value> sourceCoreIdValues =
|
||||||
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
|
||||||
|
SmallVector<Value> targetCoreIdValues =
|
||||||
|
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
|
||||||
spatial::SpatChannelSendTensorOp::create(
|
spatial::SpatChannelSendTensorOp::create(
|
||||||
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
|
||||||
for (auto op : run.ops)
|
for (auto op : run.ops)
|
||||||
|
|||||||
@@ -37,8 +37,7 @@ struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
|||||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
|
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
|
||||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||||
}
|
}
|
||||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
|
||||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
|
||||||
return lhs == rhs;
|
return lhs == rhs;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -10,8 +10,8 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "DcpScheduler.hpp"
|
|
||||||
#include "../DCPGraph/Graph.hpp"
|
#include "../DCPGraph/Graph.hpp"
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
@@ -404,7 +404,8 @@ bool coarsenGraph(const VirtualGraph &graph,
|
|||||||
VirtualNode mergedNode;
|
VirtualNode mergedNode;
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
|
||||||
|
memberNode.originalNodeIndices.end());
|
||||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||||
}
|
}
|
||||||
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
MergeScheduleResult
|
||||||
|
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
|
||||||
llvm::SmallVector<Weight> nodeWeights;
|
llvm::SmallVector<Weight> nodeWeights;
|
||||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||||
|
|||||||
+10
-15
@@ -1,13 +1,13 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ComputeGraph.hpp"
|
|
||||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
#include "DcpScheduler.hpp"
|
#include "DcpScheduler.hpp"
|
||||||
#include "MergeSchedulingAnalysis.hpp"
|
#include "MergeSchedulingAnalysis.hpp"
|
||||||
#include "PeftScheduler.hpp"
|
#include "PeftScheduler.hpp"
|
||||||
@@ -20,10 +20,8 @@ namespace {
|
|||||||
|
|
||||||
MergeSchedulerKind getSchedulerKind() {
|
MergeSchedulerKind getSchedulerKind() {
|
||||||
switch (pimMergeScheduler.getValue()) {
|
switch (pimMergeScheduler.getValue()) {
|
||||||
case MergeSchedulerPeft:
|
case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
|
||||||
return MergeSchedulerKind::Peft;
|
case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
|
||||||
case MergeSchedulerDcp:
|
|
||||||
return MergeSchedulerKind::Dcp;
|
|
||||||
}
|
}
|
||||||
llvm_unreachable("unknown merge scheduler kind");
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
}
|
}
|
||||||
@@ -115,19 +113,16 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
|||||||
|
|
||||||
MergeScheduleResult schedule;
|
MergeScheduleResult schedule;
|
||||||
if (options.kind == MergeSchedulerKind::Peft) {
|
if (options.kind == MergeSchedulerKind::Peft) {
|
||||||
schedule = runPeftScheduler(
|
schedule = runPeftScheduler(graph,
|
||||||
graph,
|
PeftScheduleOptions {options.processorCount,
|
||||||
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||||
entryOp->getContext()});
|
entryOp->getContext()});
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
schedule = runDcpScheduler(
|
schedule = runDcpScheduler(graph,
|
||||||
graph,
|
DcpScheduleOptions {options.processorCount,
|
||||||
DcpScheduleOptions {
|
|
||||||
options.processorCount,
|
|
||||||
dcpCriticalWindowSize.getValue(),
|
dcpCriticalWindowSize.getValue(),
|
||||||
options.allowDcpFallbackForAutoCoreCount
|
options.allowDcpFallbackForAutoCoreCount},
|
||||||
},
|
|
||||||
entryOp->getContext());
|
entryOp->getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||||
(void) withScalarCoreFromBatchLane(
|
(void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
|
||||||
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); });
|
return verifyCoreOperands(scalarCore, diagnostics);
|
||||||
|
});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user