automatic code reformat

This commit is contained in:
NiccoloN
2026-05-22 15:23:48 +02:00
parent d136136d22
commit 8337a11ce9
37 changed files with 312 additions and 354 deletions
+3 -3
View File
@@ -13,7 +13,8 @@
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
struct CappedDiagnosticReporter { struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {} explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn> template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) { void report(mlir::Operation* op, EmitFn&& emit) {
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const { void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures) if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
<< failureDescription;
} }
bool hasFailure() const { return numFailures != 0; } bool hasFailure() const { return numFailures != 0; }
+4 -4
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir { namespace onnx_mlir {
@@ -15,8 +15,8 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen), llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler( llvm::cl::opt<PimMergeSchedulerType>
"pim-merge-scheduler", pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>; using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithValues( detail::invokeWithValues(std::forward<BodyFn>(body),
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {}); detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = detail::invokeWithValues( auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {}); detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size()); vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
vmmOutputs.push_back(spatial::SpatVMMOp::create( vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId))); gemmLoc,
currOutHSliceType,
computeOp.getWeightArgument(aHSliceId),
computeOp.getInputArgument(aHSliceId)));
if (vmmOutputs.empty()) { if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure(); return failure();
@@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, tensor::ParallelInsertSliceOp::create(
unitStrides); rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp); rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults()); rewriter.replaceOp(gemmOp, batchOp.getResults());
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end()); return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
} }
static Value collapseBatchDims(Value value, static Value
int64_t batchSize, collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType()); auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3) if (type.getRank() == 2 || type.getRank() == 3)
return value; return value;
auto collapsedType = auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)}, ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)} ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim) for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim); reassociation.front().push_back(dim);
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
return collapseCompute.getResult(0); return collapseCompute.getResult(0);
} }
static Value expandBatchDims(Value value, static Value
RankedTensorType outputType, expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType) if (cast<RankedTensorType>(value.getType()) == outputType)
return value; return value;
SmallVector<ReassociationIndices> reassociation = { SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)}, ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)} ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
};
for (size_t dim = 0; dim < batchRank; ++dim) for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim)); reassociation.front().push_back(static_cast<int64_t>(dim));
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar(); Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC = Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody()); rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar(); Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH = Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody()); rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar(); Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front(); Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW = Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW}; SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice = Value inputSlice =
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(resizeOp,
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor."); "resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
} }
llvm::append_range(newBlockArgTypes, newInputTypes); llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs); llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock = auto* newBlock = rewriter.createBlock(
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())}); {static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
@@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
mapper.map(oldArg, *clonedValue); mapper.map(oldArg, *clonedValue);
} }
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); mapper.map(compute.getOutputArgument(resultIndex),
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock) for (Operation& op : oldBlock)
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
@@ -1,8 +1,8 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@@ -141,8 +141,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale); scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
} }
totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() totalOffset =
: scaledOffset; totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
} }
if (!totalOffset) if (!totalOffset)
@@ -233,10 +233,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
if (!computeOp.getWeights().empty()) if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter, auto coreOp = PimCoreOp::create(
loc, rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
rewriter.setInsertionPointToStart(&block); rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
@@ -217,8 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace } // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) { void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>( patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -546,11 +546,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
return ReturnPathLoweringResult::NotReturnPath; return ReturnPathLoweringResult::NotReturnPath;
} }
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp, spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
OpResult result,
Value yieldValue,
IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter); return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
} }
@@ -39,13 +39,10 @@ private:
size_t coreId = 0; size_t coreId = 0;
llvm::SmallVector<mlir::Operation*> operationsToRemove; llvm::SmallVector<mlir::Operation*> operationsToRemove;
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::IRRewriter& rewriter); mlir::LogicalResult
mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
mlir::IRRewriter& rewriter, mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
mlir::OperationFolder& constantFolder);
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
mlir::IRRewriter& rewriter);
enum class ReturnPathLoweringResult { enum class ReturnPathLoweringResult {
Handled, Handled,
+2 -2
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include <string> #include <string>
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
+4 -5
View File
@@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<O
return parser.parseRParen(); return parser.parseRParen();
} }
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) { static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter); printCompressedValueList(printer, arguments, delimiter);
printer << " = "; printer << " = ";
printCompressedValueList(printer, operands, delimiter); printCompressedValueList(printer, operands, delimiter);
@@ -82,10 +83,8 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) { switch (currentDelimiter) {
case ListDelimiter::Paren: case ListDelimiter::Paren: return parser.parseRParen();
return parser.parseRParen(); case ListDelimiter::Square: return parser.parseRSquare();
case ListDelimiter::Square:
return parser.parseRSquare();
} }
llvm_unreachable("unsupported delimiter"); llvm_unreachable("unsupported delimiter");
}; };
+7 -10
View File
@@ -1,7 +1,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
@@ -66,8 +66,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
|| isExplicitHostOperand(op, operand.getOperandNumber())) || isExplicitHostOperand(op, operand.getOperandNumber()))
continue; continue;
InFlightDiagnostic diagnostic = InFlightDiagnostic diagnostic = ownerOp->emitOpError()
ownerOp->emitOpError() << kind << " body may only directly reference external constants"; << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true; hasFailure = true;
@@ -153,10 +153,9 @@ LogicalResult PimCoreOp::verify() {
Block& block = getBody().front(); Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size()) if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight"); return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly"); return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core"); return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
} }
@@ -169,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() {
return emitError("core_batch body must have lane, weight, and input block arguments"); return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex()) if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type"); return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly"); return emitError("core_batch weight block argument types must match weight operand types exactly");
} for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType()) if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly"); return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch"); return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
} }
@@ -50,11 +50,10 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Opera
pendingValues.push_back(result); pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) { if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
if (initArg == value) if (initArg == value)
pendingValues.push_back(forOp.getResult(index)); pendingValues.push_back(forOp.getResult(index));
} }
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) { if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) { for (OpResult result : user->getResults()) {
+3 -5
View File
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string> #include <string>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -65,9 +65,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
return getRegion().front().getOperations();
}
void SpatialDialect::initialize() { void SpatialDialect::initialize() {
addTypes< addTypes<
+3 -6
View File
@@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser,
return failure(); return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) { switch (currentDelimiter) {
case ListDelimiter::Paren: case ListDelimiter::Paren: return parser.parseRParen();
return parser.parseRParen(); case ListDelimiter::Square: return parser.parseRSquare();
case ListDelimiter::Square:
return parser.parseRSquare();
} }
llvm_unreachable("unsupported delimiter"); llvm_unreachable("unsupported delimiter");
}; };
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) { if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure(); return failure();
}
return success(); return success();
} }
+5 -8
View File
@@ -273,8 +273,8 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region
InFlightDiagnostic diagnostic = ownerOp->emitOpError() InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants"; << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber() diagnostic.attachNote(op->getLoc())
<< " is used by " << op->getName(); << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true; hasFailure = true;
} }
}); });
@@ -457,14 +457,12 @@ LogicalResult SpatCompute::verify() {
if (block.getNumArguments() != expectedArgCount) if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments"); return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly"); return emitError("compute weight block argument types must match weight operand types exactly");
} for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType()) if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly"); return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) { if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator()); auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -582,10 +580,9 @@ LogicalResult SpatComputeBatch::verify() {
if (!getLaneArgument().getType().isIndex()) if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch first block argument must have index type"); return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
if (getWeightArgument(weightIndex).getType() != weight.getType()) if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly"); return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex); BlockArgument blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType()) if (blockArg.getType() != input.getType())
@@ -1,6 +1,6 @@
#include "DCPAnalysis.hpp"
#include "../Scheduling/ComputeGraph.hpp" #include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp" #include "../Scheduling/DcpScheduler.hpp"
#include "DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -1,5 +1,3 @@
#include "MaterializeMergeSchedule.hpp"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -24,6 +22,7 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "MaterializeMergeSchedule.hpp"
#include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -72,9 +71,7 @@ struct CpuSlotKey {
}; };
struct CpuSlotKeyInfo { struct CpuSlotKeyInfo {
static CpuSlotKey getEmptyKey() { static CpuSlotKey getEmptyKey() { return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()}; }
return {std::numeric_limits<CpuId>::max(), std::numeric_limits<SlotId>::max()};
}
static CpuSlotKey getTombstoneKey() { static CpuSlotKey getTombstoneKey() {
return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()}; return {std::numeric_limits<CpuId>::max() - 1, std::numeric_limits<SlotId>::max()};
@@ -139,10 +136,11 @@ struct MaterializerState {
DenseMap<Value, Value> hostReplacements; DenseMap<Value, Value> hostReplacements;
DenseSet<Operation*> oldComputeOps; DenseSet<Operation*> oldComputeOps;
MaterializerState(func::FuncOp func, MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId)
const MergeScheduleResult& schedule, : func(func),
int64_t& nextChannelId) schedule(schedule),
: func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()), rewriter(func.getContext()),
constantFolder(func.getContext()),
nextChannelId(nextChannelId) {} nextChannelId(nextChannelId) {}
}; };
@@ -189,11 +187,12 @@ std::optional<uint32_t> getConstantFirstSliceOffset(tensor::ExtractSliceOp extra
return std::nullopt; return std::nullopt;
} }
ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch, ProducerKey
uint32_t laneStart, getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) {
uint32_t laneCount, return {
size_t resultIndex) { {batch.getOperation(), laneStart, laneCount},
return {{batch.getOperation(), laneStart, laneCount}, resultIndex}; resultIndex
};
} }
ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) {
@@ -202,8 +201,8 @@ ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex)
bool isWholeBatchProducerKey(ProducerKey key) { bool isWholeBatchProducerKey(ProducerKey key) {
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op); auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 && return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0
key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount()); && key.instance.laneCount == static_cast<uint32_t>(batch.getLaneCount());
} }
SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) { SmallVector<ProducerKey, 8> expandWholeBatchProducerKey(ProducerKey key) {
@@ -306,7 +305,10 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
auto result = dyn_cast<OpResult>(value); auto result = dyn_cast<OpResult>(value);
if (!result) if (!result)
return std::nullopt; return std::nullopt;
return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()}; return ProducerKey {
{compute.getOperation(), 0, 1},
result.getResultNumber()
};
} }
if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) { if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
@@ -316,10 +318,8 @@ std::optional<ProducerKey> getProducerKey(Value value, const ComputeInstance* co
if (batch.getNumResults() != 0) { if (batch.getNumResults() != 0) {
if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op)) if (consumerInstance && isa<SpatComputeBatch>(consumerInstance->op))
return getBatchLaneProducerKey(batch, return getBatchLaneProducerKey(
consumerInstance->laneStart, batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber());
consumerInstance->laneCount,
result.getResultNumber());
return getWholeBatchProducerKey(batch, result.getResultNumber()); return getWholeBatchProducerKey(batch, result.getResultNumber());
} }
@@ -489,7 +489,8 @@ void createEmptyMaterializedOps(MaterializerState& state) {
continue; continue;
} }
auto batch = SpatComputeBatch::create(state.rewriter, auto batch =
SpatComputeBatch::create(state.rewriter,
loc, loc,
TypeRange(resultTypes), TypeRange(resultTypes),
state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())), state.rewriter.getI32IntegerAttr(static_cast<int32_t>(materializedClass.cpus.size())),
@@ -506,15 +507,13 @@ void createEmptyMaterializedOps(MaterializerState& state) {
SmallVector<Location, 4> blockArgLocs {loc}; SmallVector<Location, 4> blockArgLocs {loc};
llvm::append_range(blockArgTypes, resultTypes); llvm::append_range(blockArgTypes, resultTypes);
blockArgLocs.append(resultTypes.size(), loc); blockArgLocs.append(resultTypes.size(), loc);
Block* body = state.rewriter.createBlock( Block* body =
&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
state.rewriter.setInsertionPointToEnd(body); state.rewriter.setInsertionPointToEnd(body);
if (resultTypes.empty()) { if (resultTypes.empty())
SpatYieldOp::create(state.rewriter, loc, ValueRange {}); SpatYieldOp::create(state.rewriter, loc, ValueRange {});
} else
else {
SpatInParallelOp::create(state.rewriter, loc); SpatInParallelOp::create(state.rewriter, loc);
}
materializedClass.op = batch.getOperation(); materializedClass.op = batch.getOperation();
materializedClass.body = body; materializedClass.body = body;
state.rewriter.setInsertionPointAfter(batch.getOperation()); state.rewriter.setInsertionPointAfter(batch.getOperation());
@@ -559,7 +558,8 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
else { else {
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input)); cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
BlockArgument arg = materializedClass.body->insertArgument(materializedClass.body->getNumArguments()-1, input.getType(), input.getLoc()); BlockArgument arg = materializedClass.body->insertArgument(
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
materializedClass.inputArgs[input] = arg; materializedClass.inputArgs[input] = arg;
return arg; return arg;
} }
@@ -586,9 +586,8 @@ SmallVector<Value, 8> createIndexConstants(MaterializerState& state, Operation*
return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened)); return createIndexConstants(state, anchor, ArrayRef<int64_t>(widened));
} }
FailureOr<SmallVector<ComputeInstance, 8>> getPeerInstances(MaterializerState& state, FailureOr<SmallVector<ComputeInstance, 8>>
const MaterializedClass& materializedClass, getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) {
SlotId slot) {
SmallVector<ComputeInstance, 8> peers; SmallVector<ComputeInstance, 8> peers;
peers.reserve(materializedClass.cpus.size()); peers.reserve(materializedClass.cpus.size());
for (CpuId cpu : materializedClass.cpus) { for (CpuId cpu : materializedClass.cpus) {
@@ -774,12 +773,8 @@ Value appendReceive(MaterializerState& state,
.getOutput(); .getOutput();
} }
return SpatChannelReceiveOp::create(state.rewriter, return SpatChannelReceiveOp::create(
loc, state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
type,
channelIdValues.front(),
sourceCoreIdValues.front(),
targetCoreIdValues.front())
.getOutput(); .getOutput();
} }
@@ -802,19 +797,13 @@ Value appendHostReceive(MaterializerState& state,
} }
assert(channelIds.size() == 1 && "scalar host receive expects one channel"); assert(channelIds.size() == 1 && "scalar host receive expects one channel");
return SpatChannelReceiveOp::create(state.rewriter, return SpatChannelReceiveOp::create(
loc, state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front())
type,
channelIdValues.front(),
sourceCoreIdValues.front(),
targetCoreIdValues.front())
.getOutput(); .getOutput();
} }
LogicalResult setHostOutputValue(MaterializerState& state, LogicalResult
MaterializedClass& sourceClass, setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) {
Value originalOutput,
Value payload) {
auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput);
if (resultIt == sourceClass.hostOutputToResultIndex.end()) if (resultIt == sourceClass.hostOutputToResultIndex.end())
return sourceClass.op->emitError("missing host result slot for materialized output"); return sourceClass.op->emitError("missing host result slot for materialized output");
@@ -908,8 +897,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front()); int32_t sourceCpu = static_cast<int32_t>(sourceClass.cpus.front());
int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front()); int32_t targetCpu = static_cast<int32_t>(targetClass.cpus.front());
appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc); appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc);
Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef<int64_t>(channelId), Value received = appendReceive(state,
ArrayRef<int32_t>(sourceCpu), ArrayRef<int32_t>(targetCpu), loc); targetClass,
payload.getType(),
ArrayRef<int64_t>(channelId),
ArrayRef<int32_t>(sourceCpu),
ArrayRef<int32_t>(targetCpu),
loc);
for (ProducerKey key : keys) for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received; state.availableValues[key][targetClass.id] = received;
return success(); return success();
@@ -937,7 +931,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
loc); loc);
} }
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys) for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received; state.availableValues[key][targetClass.id] = received;
return success(); return success();
@@ -946,8 +941,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
if (sourceClass.isBatch && !targetClass.isBatch) { if (sourceClass.isBatch && !targetClass.isBatch) {
std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys); std::optional<ProducerKey> packedKey = getContiguousProducerKeyForKeys(keys);
if (!packedKey) if (!packedKey)
return sourceClass.op->emitError( return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source "
"cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order"); "lanes are not contiguous in send order");
FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size()); FailureOr<RankedTensorType> packedType = getPackedBatchTensorType(payload.getType(), keys.size());
if (failed(packedType)) if (failed(packedType))
@@ -975,7 +970,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
if (sourceClass.isBatch && targetClass.isBatch) { if (sourceClass.isBatch && targetClass.isBatch) {
if (sourceClass.cpus.size() != targetClass.cpus.size()) if (sourceClass.cpus.size() != targetClass.cpus.size())
return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes"); return sourceClass.op->emitError(
"cannot materialize batch communication between equivalence classes of different sizes");
SmallVector<int64_t, 8> channelIds; SmallVector<int64_t, 8> channelIds;
SmallVector<int32_t, 8> sourceCoreIds; SmallVector<int32_t, 8> sourceCoreIds;
@@ -991,7 +987,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state,
} }
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); Value received =
appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
for (ProducerKey key : keys) for (ProducerKey key : keys)
state.availableValues[key][targetClass.id] = received; state.availableValues[key][targetClass.id] = received;
return success(); return success();
@@ -1025,7 +1022,8 @@ LogicalResult emitHostCommunication(MaterializerState& state,
} }
appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc);
Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); Value received =
appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc);
state.hostReplacements[originalOutput] = received; state.hostReplacements[originalOutput] = received;
return success(); return success();
} }
@@ -1041,7 +1039,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
if (!sourceClass.isBatch) { if (!sourceClass.isBatch) {
for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front()))
if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) if (failed(
emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc)))
return failure(); return failure();
if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc)))
return failure(); return failure();
@@ -1065,11 +1064,8 @@ LogicalResult emitOutputFanout(MaterializerState& state,
return success(); return success();
} }
FailureOr<Value> materializeWholeBatchInput(MaterializerState& state, FailureOr<Value> materializeWholeBatchInput(
MaterializedClass& targetClass, MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) {
ProducerKey key,
Type resultType,
Location loc) {
auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op); auto batch = dyn_cast_or_null<SpatComputeBatch>(key.instance.op);
auto resultTensorType = dyn_cast<RankedTensorType>(resultType); auto resultTensorType = dyn_cast<RankedTensorType>(resultType);
if (!batch || !resultTensorType || resultTensorType.getRank() == 0) if (!batch || !resultTensorType || resultTensorType.getRank() == 0)
@@ -1200,9 +1196,8 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
return outputs; return outputs;
} }
FailureOr<SmallVector<Value, 4>> cloneInstanceBody(MaterializerState& state, FailureOr<SmallVector<Value, 4>>
MaterializedClass& targetClass, cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef<ComputeInstance> peers) {
ArrayRef<ComputeInstance> peers) {
assert(!peers.empty() && "expected at least one peer instance"); assert(!peers.empty() && "expected at least one peer instance");
const ComputeInstance& instance = peers.front(); const ComputeInstance& instance = peers.front();
Operation* sourceOp = instance.op; Operation* sourceOp = instance.op;
@@ -1318,9 +1313,8 @@ LogicalResult eraseOldComputeOps(MaterializerState& state) {
} // namespace } // namespace
LogicalResult MergeScheduleMaterializer::run(func::FuncOp func, LogicalResult
const MergeScheduleResult& schedule, MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
int64_t& nextChannelId) {
if (schedule.dominanceOrderCompute.empty()) if (schedule.dominanceOrderCompute.empty())
return success(); return success();
@@ -10,8 +10,7 @@ namespace spatial {
class MergeScheduleMaterializer { class MergeScheduleMaterializer {
public: public:
mlir::LogicalResult mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId);
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
}; };
} // namespace spatial } // namespace spatial
@@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n
class ScopedMergePhaseTimer { class ScopedMergePhaseTimer {
public: public:
explicit ScopedMergePhaseTimer(StringRef phaseName) explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
phase(phaseName.str()) {
if (enabled) if (enabled)
start = std::chrono::steady_clock::now(); start = std::chrono::steady_clock::now();
} }
@@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) {
MergeIrCounts counts = collectMergeIrCounts(funcOp); MergeIrCounts counts = collectMergeIrCounts(funcOp);
llvm::errs() << "[merge-profile] " << phaseName << " counts:" llvm::errs() << "[merge-profile] " << phaseName << " counts:"
<< " compute=" << counts.topLevelComputeCount << " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount
<< " compute_batch=" << counts.topLevelComputeBatchCount
<< " scalar_send=" << counts.scalarChannelSendCount << " scalar_send=" << counts.scalarChannelSendCount
<< " scalar_recv=" << counts.scalarChannelReceiveCount << " scalar_recv=" << counts.scalarChannelReceiveCount
<< " tensor_send=" << counts.tensorChannelSendCount << " tensor_send=" << counts.tensorChannelSendCount
<< " tensor_recv=" << counts.tensorChannelReceiveCount << " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount
<< " wvmm=" << counts.wvmmCount << " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n";
<< " vadd=" << counts.vaddCount
<< " scf_for=" << counts.scfForCount << "\n";
} }
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) { static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
@@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) {
return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size(); return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size();
} }
SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights, ValueRange sourceWeights) { SmallVector<size_t> appendMissingWeightsAndBuildIndexMap(SmallVectorImpl<Value>& targetWeights,
ValueRange sourceWeights) {
DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices; DenseMap<Value, SmallVector<size_t, 4>> targetWeightIndices;
for (auto [weightIndex, weight] : llvm::enumerate(targetWeights)) for (auto [weightIndex, weight] : llvm::enumerate(targetWeights))
targetWeightIndices[weight].push_back(weightIndex); targetWeightIndices[weight].push_back(weightIndex);
@@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) { if (packedInput) {
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); SmallVector<Value> sourceCoreIdValues =
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues =
createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
spatial::SpatChannelSendTensorOp::create( spatial::SpatChannelSendTensorOp::create(
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput); rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
for (auto op : run.ops) for (auto op : run.ops)
@@ -37,8 +37,7 @@ struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) { static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount); return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
} }
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs, static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) {
const onnx_mlir::spatial::ComputeInstance &rhs) {
return lhs == rhs; return lhs == rhs;
} }
}; };
@@ -10,8 +10,8 @@
#include <queue> #include <queue>
#include <vector> #include <vector>
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp" #include "../DCPGraph/Graph.hpp"
#include "DcpScheduler.hpp"
namespace onnx_mlir { namespace onnx_mlir {
namespace spatial { namespace spatial {
@@ -404,7 +404,8 @@ bool coarsenGraph(const VirtualGraph &graph,
VirtualNode mergedNode; VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) { for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex]; const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end()); mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(),
memberNode.originalNodeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight); mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage); mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
} }
@@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp
return result; return result;
} }
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) { MergeScheduleResult
runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) {
llvm::SmallVector<Weight> nodeWeights; llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage; llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys; llvm::SmallVector<int64_t> nodeOrderKeys;
@@ -1,13 +1,13 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include <limits> #include <limits>
#include <vector> #include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp" #include "../DCPGraph/DCPAnalysis.hpp"
#include "ComputeGraph.hpp"
#include "DcpScheduler.hpp" #include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp" #include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp" #include "PeftScheduler.hpp"
@@ -20,10 +20,8 @@ namespace {
MergeSchedulerKind getSchedulerKind() { MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) { switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft: case MergeSchedulerPeft: return MergeSchedulerKind::Peft;
return MergeSchedulerKind::Peft; case MergeSchedulerDcp: return MergeSchedulerKind::Dcp;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
} }
llvm_unreachable("unknown merge scheduler kind"); llvm_unreachable("unknown merge scheduler kind");
} }
@@ -115,19 +113,16 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
MergeScheduleResult schedule; MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) { if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler( schedule = runPeftScheduler(graph,
graph, PeftScheduleOptions {options.processorCount,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()), static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()}); entryOp->getContext()});
} }
else { else {
schedule = runDcpScheduler( schedule = runDcpScheduler(graph,
graph, DcpScheduleOptions {options.processorCount,
DcpScheduleOptions {
options.processorCount,
dcpCriticalWindowSize.getValue(), dcpCriticalWindowSize.getValue(),
options.allowDcpFallbackForAutoCoreCount options.allowDcpFallbackForAutoCoreCount},
},
entryOp->getContext()); entryOp->getContext());
} }
+4 -3
View File
@@ -8,8 +8,8 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) { if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
(void) withScalarCoreFromBatchLane( (void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) {
coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); }); return verifyCoreOperands(scalarCore, diagnostics);
});
continue; continue;
} }