diff --git a/src/PIM/Common/Support/Diagnostics.hpp b/src/PIM/Common/Support/Diagnostics.hpp index 5d84111..11e3d78 100644 --- a/src/PIM/Common/Support/Diagnostics.hpp +++ b/src/PIM/Common/Support/Diagnostics.hpp @@ -13,7 +13,8 @@ namespace onnx_mlir::pim { struct CappedDiagnosticReporter { - explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {} + explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) + : maxReportedFailures(maxReportedFailures) {} template void report(mlir::Operation* op, EmitFn&& emit) { @@ -24,8 +25,7 @@ struct CappedDiagnosticReporter { void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const { if (numFailures > maxReportedFailures) - op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " - << failureDescription; + op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription; } bool hasFailure() const { return numFailures != 0; } diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 37410ed..515bff5 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -1,7 +1,7 @@ -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" - #include "llvm/Support/ErrorHandling.h" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + #define DEBUG_TYPE "PimCompilerOptions" namespace onnx_mlir { @@ -15,13 +15,13 @@ llvm::cl::opt pimEmissionTarget( llvm::cl::init(EmitPimCodegen), llvm::cl::cat(OnnxMlirOptions)); -llvm::cl::opt pimMergeScheduler( - "pim-merge-scheduler", - llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), - llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), - llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), - llvm::cl::init(MergeSchedulerPeft), - llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt + pimMergeScheduler("pim-merge-scheduler", + llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"), + llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")), + llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")), + llvm::cl::init(MergeSchedulerPeft), + llvm::cl::cat(OnnxMlirOptions)); llvm::cl::opt pimOnlyCodegen("pim-only-codegen", diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp index e78d067..7355c54 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp @@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter, using BodyResult = detail::InvokeWithBlockArgsResultT, std::make_index_sequence>; if constexpr (std::is_same_v) { - detail::invokeWithValues( - std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); + detail::invokeWithValues(std::forward(body), + detail::getInputBlockArgs(block, weights.size()), + std::make_index_sequence {}); rewriter.setInsertionPointAfter(computeOp); return computeOp; } else { - auto bodyResult = detail::invokeWithValues( - std::forward(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence {}); + auto bodyResult = detail::invokeWithValues(std::forward(body), + detail::getInputBlockArgs(block, weights.size()), + std::make_index_sequence {}); if (mlir::failed(bodyResult)) { rewriter.setInsertionPointAfter(computeOp); rewriter.eraseOp(computeOp); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index ec79393..707b71e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, SmallVector vmmOutputs; vmmOutputs.reserve(aHSlices[coreId].size()); for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) - vmmOutputs.push_back(spatial::SpatVMMOp::create( - rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId))); + vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, + gemmLoc, + currOutHSliceType, + computeOp.getWeightArgument(aHSliceId), + computeOp.getInputArgument(aHSliceId))); if (vmmOutputs.empty()) { gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); return failure(); @@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); SmallVector outputOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; - tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, - unitStrides); + tensor::ParallelInsertSliceOp::create( + rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides); rewriter.setInsertionPointAfter(batchOp); rewriter.replaceOp(gemmOp, batchOp.getResults()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index dd1d227..03a972a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -38,23 +38,16 @@ static FailureOr> inferSupportedBatchShape(ArrayRef(lhsBatchShape.begin(), lhsBatchShape.end()); } -static Value collapseBatchDims(Value value, - int64_t batchSize, - int64_t rows, - int64_t cols, - PatternRewriter& rewriter, - Location loc) { +static Value +collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); if (type.getRank() == 2 || type.getRank() == 3) return value; - auto collapsedType = - RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); - SmallVector reassociation = { - ReassociationIndices {}, - ReassociationIndices {static_cast(type.getRank() - 2)}, - ReassociationIndices {static_cast(type.getRank() - 1)} - }; + auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); + SmallVector reassociation = {ReassociationIndices {}, + ReassociationIndices {static_cast(type.getRank() - 2)}, + ReassociationIndices {static_cast(type.getRank() - 1)}}; for (int64_t dim = 0; dim < type.getRank() - 2; ++dim) reassociation.front().push_back(dim); @@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value, return collapseCompute.getResult(0); } -static Value expandBatchDims(Value value, - RankedTensorType outputType, - size_t batchRank, - PatternRewriter& rewriter, - Location loc) { +static Value +expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) { if (cast(value.getType()) == outputType) return value; - SmallVector reassociation = { - ReassociationIndices {}, - ReassociationIndices {static_cast(batchRank)}, - ReassociationIndices {static_cast(batchRank + 1)} - }; + SmallVector reassociation = {ReassociationIndices {}, + ReassociationIndices {static_cast(batchRank)}, + ReassociationIndices {static_cast(batchRank + 1)}}; for (size_t dim = 0; dim < batchRank; ++dim) reassociation.front().push_back(static_cast(dim)); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index 0749c20..be38062 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input, Value outputC = channelLoop.getInductionVar(); Value outputChannelAcc = channelLoop.getRegionIterArgs().front(); - Value inputC = - buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc); + Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc); auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc}); rewriter.setInsertionPointToStart(heightLoop.getBody()); Value outputH = heightLoop.getInductionVar(); Value outputHeightAcc = heightLoop.getRegionIterArgs().front(); - Value inputH = - buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc); + Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc); auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc}); rewriter.setInsertionPointToStart(widthLoop.getBody()); Value outputW = widthLoop.getInductionVar(); Value outputWidthAcc = widthLoop.getRegionIterArgs().front(); - Value inputW = - buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc); + Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc); SmallVector inputOffsets = {inputN, inputC, inputH, inputW}; Value inputSlice = @@ -114,8 +111,8 @@ struct Resize : OpConversionPattern { if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" || resizeOp.getNearestMode() != "floor") - return rewriter.notifyMatchFailure( - resizeOp, "resize lowering currently supports only nearest + asymmetric + floor."); + return rewriter.notifyMatchFailure(resizeOp, + "resize lowering currently supports only nearest + asymmetric + floor."); if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp index 8b09290..8a41f24 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp @@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern(newWeights.size()), static_cast(newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); @@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(0, compute.getNumResults())) - mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); + mapper.map(compute.getOutputArgument(resultIndex), + newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); for (Operation& op : oldBlock) rewriter.clone(op, mapper); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 458f521..8cacfb2 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -1,8 +1,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -141,8 +141,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter, scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast(offset)), scale); } - totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() - : scaledOffset; + totalOffset = + totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset; } if (!totalOffset) diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp index bc1e577..655aef3 100644 --- a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp @@ -30,8 +30,8 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, Value replacement) { Block& body = owner->getRegion(0).front(); BlockArgument bodyArgument = isa(owner) - ? cast(owner).getInputArgument(inputIndex) - : cast(owner).getInputArgument(inputIndex); + ? cast(owner).getInputArgument(inputIndex) + : cast(owner).getInputArgument(inputIndex); unsigned bodyArgIndex = bodyArgument.getArgNumber(); rewriter.startOpModification(owner); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 989dfea..171dde5 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -233,10 +233,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp if (!computeOp.getWeights().empty()) computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); rewriter.setInsertionPointAfter(computeOp); - auto coreOp = PimCoreOp::create(rewriter, - loc, - ValueRange(computeWeights), - rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); + auto coreOp = PimCoreOp::create( + rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); rewriter.setInsertionPointToStart(&block); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) { diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp index b0e5af6..c406f43 100644 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp @@ -217,8 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern( - patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 82a7aef..859e4af 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -546,11 +546,8 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low return ReturnPathLoweringResult::NotReturnPath; } -raptor::SpatialToPimPass::ReturnPathLoweringResult -raptor::SpatialToPimPass::lowerComputeResultReturnPath(spatial::SpatCompute computeOp, - OpResult result, - Value yieldValue, - IRRewriter& rewriter) { +raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath( + spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) { return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 5130722..de992da 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -318,7 +318,7 @@ void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp f } LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, - IRRewriter& rewriter) { + IRRewriter& rewriter) { Location loc = funcOp.getLoc(); OperationFolder constantFolder(funcOp.getContext()); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp index 7e17fd0..010aadc 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.hpp @@ -39,13 +39,10 @@ private: size_t coreId = 0; llvm::SmallVector operationsToRemove; - mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, - mlir::IRRewriter& rewriter); - mlir::LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, - mlir::IRRewriter& rewriter, - mlir::OperationFolder& constantFolder); - mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, - mlir::IRRewriter& rewriter); + mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); + mlir::LogicalResult + lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder); + mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter); enum class ReturnPathLoweringResult { Handled, diff --git a/src/PIM/Dialect/Pim/PimOps.cpp b/src/PIM/Dialect/Pim/PimOps.cpp index 57bb1bc..a6d06d6 100644 --- a/src/PIM/Dialect/Pim/PimOps.cpp +++ b/src/PIM/Dialect/Pim/PimOps.cpp @@ -1,7 +1,7 @@ -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" - #include +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + using namespace mlir; namespace onnx_mlir { diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp index 283773f..2bd4df5 100644 --- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -56,7 +56,8 @@ static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImplemitOpError() << kind << " body may only directly reference external constants"; + InFlightDiagnostic diagnostic = ownerOp->emitOpError() + << kind << " body may only directly reference external constants"; diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); hasFailure = true; @@ -153,10 +153,9 @@ LogicalResult PimCoreOp::verify() { Block& block = getBody().front(); if (block.getNumArguments() != getWeights().size()) return emitError("core body must have one block argument per weight"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) if (getWeightArgument(weightIndex).getType() != weight.getType()) return emitError("core weight block argument types must match weight operand types exactly"); - } return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core"); } @@ -169,14 +168,12 @@ LogicalResult PimCoreBatchOp::verify() { return emitError("core_batch body must have lane, weight, and input block arguments"); if (!getLaneArgument().getType().isIndex()) return emitError("core_batch first block argument must have index type"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) if (getWeightArgument(weightIndex).getType() != weight.getType()) return emitError("core_batch weight block argument types must match weight operand types exactly"); - } - for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) if (getInputArgument(inputIndex).getType() != input.getType()) return emitError("core_batch input block argument types must match input operand types exactly"); - } return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch"); } diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp index a1c6f0a..45d7a70 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp @@ -50,10 +50,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap(user)) { - for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { + for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) if (initArg == value) pendingValues.push_back(forOp.getResult(index)); - } } if (auto dpsOp = dyn_cast(user)) { diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 43aff68..41d3146 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -1,7 +1,7 @@ -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - #include +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + using namespace mlir; namespace onnx_mlir { @@ -65,9 +65,7 @@ void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) { OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); } -llvm::iterator_range SpatInParallelOp::getYieldingOps() { - return getRegion().front().getOperations(); -} +llvm::iterator_range SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); } void SpatialDialect::initialize() { addTypes< diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 5e517fe..d3e49ce 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -104,16 +104,13 @@ static ParseResult parseBoundValueList(OpAsmParser& parser, return failure(); auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult { switch (currentDelimiter) { - case ListDelimiter::Paren: - return parser.parseRParen(); - case ListDelimiter::Square: - return parser.parseRSquare(); + case ListDelimiter::Paren: return parser.parseRParen(); + case ListDelimiter::Square: return parser.parseRSquare(); } llvm_unreachable("unsupported delimiter"); }; - if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) { + if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) return failure(); - } return success(); } diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 8efd47d..e753e41 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -272,9 +272,9 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region continue; InFlightDiagnostic diagnostic = ownerOp->emitOpError() - << kind << " body may only directly reference external constants"; - diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber() - << " is used by " << op->getName(); + << kind << " body may only directly reference external constants"; + diagnostic.attachNote(op->getLoc()) + << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName(); hasFailure = true; } }); @@ -457,14 +457,12 @@ LogicalResult SpatCompute::verify() { if (block.getNumArguments() != expectedArgCount) return emitError("compute body must have weight and input block arguments"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) if (getWeightArgument(weightIndex).getType() != weight.getType()) return emitError("compute weight block argument types must match weight operand types exactly"); - } - for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { + for (auto [inputIndex, input] : llvm::enumerate(getInputs())) if (getInputArgument(inputIndex).getType() != input.getType()) return emitError("compute input block argument types must match input operand types exactly"); - } if (block.mightHaveTerminator()) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); @@ -582,10 +580,9 @@ LogicalResult SpatComputeBatch::verify() { if (!getLaneArgument().getType().isIndex()) return emitError("compute_batch first block argument must have index type"); - for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) { + for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) if (getWeightArgument(weightIndex).getType() != weight.getType()) return emitError("compute_batch weight block argument types must match weight operand types exactly"); - } for (auto [inputIndex, input] : llvm::enumerate(getInputs())) { BlockArgument blockArg = getInputArgument(inputIndex); if (blockArg.getType() != input.getType()) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 45d5fb0..cfc03d3 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -1,6 +1,6 @@ -#include "DCPAnalysis.hpp" #include "../Scheduling/ComputeGraph.hpp" #include "../Scheduling/DcpScheduler.hpp" +#include "DCPAnalysis.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" namespace onnx_mlir { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index 580616c..eec2d6b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -11,15 +11,15 @@ using DCPAnalysisResult = MergeScheduleResult; struct DCPAnalysis { private: DCPAnalysisResult result; - mlir::Operation *entryOp; + mlir::Operation* entryOp; DCPAnalysisResult run(); public: - DCPAnalysis(mlir::Operation *op) + DCPAnalysis(mlir::Operation* op) : entryOp(op) { result = run(); } - DCPAnalysisResult &getResult() { return result; } + DCPAnalysisResult& getResult() { return result; } }; } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index dded1d0..54da355 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,5 +1,3 @@ -#include "MaterializeMergeSchedule.hpp" - #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -24,6 +22,7 @@ #include #include +#include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" @@ -72,9 +71,7 @@ struct CpuSlotKey { }; struct CpuSlotKeyInfo { - static CpuSlotKey getEmptyKey() { - return {std::numeric_limits::max(), std::numeric_limits::max()}; - } + static CpuSlotKey getEmptyKey() { return {std::numeric_limits::max(), std::numeric_limits::max()}; } static CpuSlotKey getTombstoneKey() { return {std::numeric_limits::max() - 1, std::numeric_limits::max()}; @@ -139,11 +136,12 @@ struct MaterializerState { DenseMap hostReplacements; DenseSet oldComputeOps; - MaterializerState(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) - : func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()), - nextChannelId(nextChannelId) {} + MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) + : func(func), + schedule(schedule), + rewriter(func.getContext()), + constantFolder(func.getContext()), + nextChannelId(nextChannelId) {} }; bool isConstantLike(Value value) { @@ -189,11 +187,12 @@ std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extra return std::nullopt; } -ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch, - uint32_t laneStart, - uint32_t laneCount, - size_t resultIndex) { - return {{batch.getOperation(), laneStart, laneCount}, resultIndex}; +ProducerKey +getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) { + return { + {batch.getOperation(), laneStart, laneCount}, + resultIndex + }; } ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { @@ -202,8 +201,8 @@ ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) bool isWholeBatchProducerKey(ProducerKey key) { auto batch = dyn_cast_or_null(key.instance.op); - return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 && - key.instance.laneCount == static_cast(batch.getLaneCount()); + return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 + && key.instance.laneCount == static_cast(batch.getLaneCount()); } SmallVector expandWholeBatchProducerKey(ProducerKey key) { @@ -306,7 +305,10 @@ std::optional getProducerKey(Value value, const ComputeInstance* co auto result = dyn_cast(value); if (!result) return std::nullopt; - return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()}; + return ProducerKey { + {compute.getOperation(), 0, 1}, + result.getResultNumber() + }; } if (auto batch = dyn_cast(definingOp)) { @@ -316,10 +318,8 @@ std::optional getProducerKey(Value value, const ComputeInstance* co if (batch.getNumResults() != 0) { if (consumerInstance && isa(consumerInstance->op)) - return getBatchLaneProducerKey(batch, - consumerInstance->laneStart, - consumerInstance->laneCount, - result.getResultNumber()); + return getBatchLaneProducerKey( + batch, consumerInstance->laneStart, consumerInstance->laneCount, result.getResultNumber()); return getWholeBatchProducerKey(batch, result.getResultNumber()); } @@ -489,12 +489,13 @@ void createEmptyMaterializedOps(MaterializerState& state) { continue; } - auto batch = SpatComputeBatch::create(state.rewriter, - loc, - TypeRange(resultTypes), - state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.size())), - ValueRange {}, - ValueRange {}); + auto batch = + SpatComputeBatch::create(state.rewriter, + loc, + TypeRange(resultTypes), + state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.size())), + ValueRange {}, + ValueRange {}); batch.getProperties().setOperandSegmentSizes({0, 0}); SmallVector coreIds; coreIds.reserve(materializedClass.cpus.size()); @@ -506,15 +507,13 @@ void createEmptyMaterializedOps(MaterializerState& state) { SmallVector blockArgLocs {loc}; llvm::append_range(blockArgTypes, resultTypes); blockArgLocs.append(resultTypes.size(), loc); - Block* body = state.rewriter.createBlock( - &batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + Block* body = + state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); state.rewriter.setInsertionPointToEnd(body); - if (resultTypes.empty()) { + if (resultTypes.empty()) SpatYieldOp::create(state.rewriter, loc, ValueRange {}); - } - else { + else SpatInParallelOp::create(state.rewriter, loc); - } materializedClass.op = batch.getOperation(); materializedClass.body = body; state.rewriter.setInsertionPointAfter(batch.getOperation()); @@ -559,7 +558,8 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali else { cast(materializedClass.op).getInputsMutable().append(ValueRange(input)); setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); - BlockArgument arg = materializedClass.body->insertArgument(materializedClass.body->getNumArguments()-1, input.getType(), input.getLoc()); + BlockArgument arg = materializedClass.body->insertArgument( + materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc()); materializedClass.inputArgs[input] = arg; return arg; } @@ -586,9 +586,8 @@ SmallVector createIndexConstants(MaterializerState& state, Operation* return createIndexConstants(state, anchor, ArrayRef(widened)); } -FailureOr> getPeerInstances(MaterializerState& state, - const MaterializedClass& materializedClass, - SlotId slot) { +FailureOr> +getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) { SmallVector peers; peers.reserve(materializedClass.cpus.size()); for (CpuId cpu : materializedClass.cpus) { @@ -601,9 +600,9 @@ FailureOr> getPeerInstances(MaterializerState& s } Value createOriginalLaneValue(MaterializerState& state, - MaterializedClass& materializedClass, - ArrayRef peers, - Location loc) { + MaterializedClass& materializedClass, + ArrayRef peers, + Location loc) { assert(!peers.empty() && "expected at least one peer instance"); if (!materializedClass.isBatch) return createIndexConstant(state, materializedClass.op, peers.front().laneStart); @@ -751,12 +750,12 @@ SmallVector getSortedDestinationClasses(MaterializerState& state, Pr } Value appendReceive(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { + MaterializedClass& targetClass, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); SmallVector channelIdValues = createIndexConstants(state, targetClass.op, channelIds); SmallVector sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds); @@ -774,22 +773,18 @@ Value appendReceive(MaterializerState& state, .getOutput(); } - return SpatChannelReceiveOp::create(state.rewriter, - loc, - type, - channelIdValues.front(), - sourceCoreIdValues.front(), - targetCoreIdValues.front()) + return SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front()) .getOutput(); } Value appendHostReceive(MaterializerState& state, - MaterializedClass& sourceClass, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { + MaterializedClass& sourceClass, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { state.rewriter.setInsertionPointAfter(sourceClass.op); SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); @@ -802,19 +797,13 @@ Value appendHostReceive(MaterializerState& state, } assert(channelIds.size() == 1 && "scalar host receive expects one channel"); - return SpatChannelReceiveOp::create(state.rewriter, - loc, - type, - channelIdValues.front(), - sourceCoreIdValues.front(), - targetCoreIdValues.front()) + return SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front()) .getOutput(); } -LogicalResult setHostOutputValue(MaterializerState& state, - MaterializedClass& sourceClass, - Value originalOutput, - Value payload) { +LogicalResult +setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); if (resultIt == sourceClass.hostOutputToResultIndex.end()) return sourceClass.op->emitError("missing host result slot for materialized output"); @@ -864,12 +853,12 @@ LogicalResult setHostOutputValue(MaterializerState& state, } void appendScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { + MaterializedClass& sourceClass, + Value payload, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); @@ -878,12 +867,12 @@ void appendScalarSend(MaterializerState& state, } void appendBatchSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); @@ -892,11 +881,11 @@ void appendBatchSend(MaterializerState& state, } LogicalResult emitClassToClassCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Value payload, - Location loc) { + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Value payload, + Location loc) { if (sourceClass.id == targetClass.id) { for (ProducerKey key : keys) state.availableValues[key][targetClass.id] = payload; @@ -908,8 +897,13 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, int32_t sourceCpu = static_cast(sourceClass.cpus.front()); int32_t targetCpu = static_cast(targetClass.cpus.front()); appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc); - Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef(channelId), - ArrayRef(sourceCpu), ArrayRef(targetCpu), loc); + Value received = appendReceive(state, + targetClass, + payload.getType(), + ArrayRef(channelId), + ArrayRef(sourceCpu), + ArrayRef(targetCpu), + loc); for (ProducerKey key : keys) state.availableValues[key][targetClass.id] = received; return success(); @@ -937,7 +931,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, loc); } - Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); for (ProducerKey key : keys) state.availableValues[key][targetClass.id] = received; return success(); @@ -946,8 +941,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, if (sourceClass.isBatch && !targetClass.isBatch) { std::optional packedKey = getContiguousProducerKeyForKeys(keys); if (!packedKey) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order"); + return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source " + "lanes are not contiguous in send order"); FailureOr packedType = getPackedBatchTensorType(payload.getType(), keys.size()); if (failed(packedType)) @@ -975,7 +970,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, if (sourceClass.isBatch && targetClass.isBatch) { if (sourceClass.cpus.size() != targetClass.cpus.size()) - return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes"); + return sourceClass.op->emitError( + "cannot materialize batch communication between equivalence classes of different sizes"); SmallVector channelIds; SmallVector sourceCoreIds; @@ -991,7 +987,8 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, } appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); for (ProducerKey key : keys) state.availableValues[key][targetClass.id] = received; return success(); @@ -1001,11 +998,11 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, } LogicalResult emitHostCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) return success(); @@ -1025,23 +1022,25 @@ LogicalResult emitHostCommunication(MaterializerState& state, } appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); state.hostReplacements[originalOutput] = received; return success(); } LogicalResult emitOutputFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { if (keys.empty()) return success(); if (!sourceClass.isBatch) { for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) - if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) + if (failed( + emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) return failure(); @@ -1065,11 +1064,8 @@ LogicalResult emitOutputFanout(MaterializerState& state, return success(); } -FailureOr materializeWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType, - Location loc) { +FailureOr materializeWholeBatchInput( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { auto batch = dyn_cast_or_null(key.instance.op); auto resultTensorType = dyn_cast(resultType); if (!batch || !resultTensorType || resultTensorType.getRank() == 0) @@ -1115,9 +1111,9 @@ FailureOr materializeWholeBatchInput(MaterializerState& state, } FailureOr resolveInputValue(MaterializerState& state, - MaterializedClass& targetClass, - Value input, - const ComputeInstance& consumerInstance) { + MaterializedClass& targetClass, + Value input, + const ComputeInstance& consumerInstance) { if (isConstantLike(input)) return input; @@ -1135,9 +1131,9 @@ FailureOr resolveInputValue(MaterializerState& state, } void mapWeights(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper) { + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper) { Operation* op = instance.op; if (auto compute = dyn_cast(op)) { for (auto [index, weight] : llvm::enumerate(compute.getWeights())) @@ -1151,9 +1147,9 @@ void mapWeights(MaterializerState& state, } LogicalResult mapInputs(MaterializerState& state, - MaterializedClass& targetClass, - const ComputeInstance& instance, - IRMapping& mapper) { + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper) { Operation* op = instance.op; if (auto compute = dyn_cast(op)) { for (auto [index, input] : llvm::enumerate(compute.getInputs())) { @@ -1200,9 +1196,8 @@ SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin return outputs; } -FailureOr> cloneInstanceBody(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef peers) { +FailureOr> +cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { assert(!peers.empty() && "expected at least one peer instance"); const ComputeInstance& instance = peers.front(); Operation* sourceOp = instance.op; @@ -1318,9 +1313,8 @@ LogicalResult eraseOldComputeOps(MaterializerState& state) { } // namespace -LogicalResult MergeScheduleMaterializer::run(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) { +LogicalResult +MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { if (schedule.dominanceOrderCompute.empty()) return success(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.hpp index 7bbc7a2..3663086 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.hpp @@ -10,8 +10,7 @@ namespace spatial { class MergeScheduleMaterializer { public: - mlir::LogicalResult - run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId); + mlir::LogicalResult run(mlir::func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId); }; } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 7e503e1..13b8d72 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -57,8 +57,7 @@ bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != n class ScopedMergePhaseTimer { public: explicit ScopedMergePhaseTimer(StringRef phaseName) - : enabled(isMergeProfilingEnabled()), - phase(phaseName.str()) { + : enabled(isMergeProfilingEnabled()), phase(phaseName.str()) { if (enabled) start = std::chrono::steady_clock::now(); } @@ -130,15 +129,12 @@ void emitMergeIrCounts(StringRef phaseName, func::FuncOp funcOp) { MergeIrCounts counts = collectMergeIrCounts(funcOp); llvm::errs() << "[merge-profile] " << phaseName << " counts:" - << " compute=" << counts.topLevelComputeCount - << " compute_batch=" << counts.topLevelComputeBatchCount + << " compute=" << counts.topLevelComputeCount << " compute_batch=" << counts.topLevelComputeBatchCount << " scalar_send=" << counts.scalarChannelSendCount << " scalar_recv=" << counts.scalarChannelReceiveCount << " tensor_send=" << counts.tensorChannelSendCount - << " tensor_recv=" << counts.tensorChannelReceiveCount - << " wvmm=" << counts.wvmmCount - << " vadd=" << counts.vaddCount - << " scf_for=" << counts.scfForCount << "\n"; + << " tensor_recv=" << counts.tensorChannelReceiveCount << " wvmm=" << counts.wvmmCount + << " vadd=" << counts.vaddCount << " scf_for=" << counts.scfForCount << "\n"; } static std::optional getComputeCoreId(SpatCompute compute) { @@ -167,7 +163,8 @@ bool isTrivialSerialMergeCandidate(SpatCompute compute) { return user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size(); } -SmallVector appendMissingWeightsAndBuildIndexMap(SmallVectorImpl& targetWeights, ValueRange sourceWeights) { +SmallVector appendMissingWeightsAndBuildIndexMap(SmallVectorImpl& targetWeights, + ValueRange sourceWeights) { DenseMap> targetWeightIndices; for (auto [weightIndex, weight] : llvm::enumerate(targetWeights)) targetWeightIndices[weight].push_back(weightIndex); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp index b3638b7..2797fc2 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.hpp @@ -7,6 +7,6 @@ namespace onnx_mlir { -mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId); +mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t& nextChannelId); } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index e56a36a..a13c66f 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -707,8 +707,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc()); if (packedInput) { SmallVector channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder); - SmallVector sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); - SmallVector targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); + SmallVector sourceCoreIdValues = + createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder); + SmallVector targetCoreIdValues = + createIndexConstants(run.ops.front(), targetCoreIds, constantFolder); spatial::SpatChannelSendTensorOp::create( rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput); for (auto op : run.ops) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp index 5481a76..d72e17d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp @@ -39,11 +39,11 @@ struct ComputeGraph { llvm::DenseMap instanceToIndex; }; -ComputeGraph buildComputeGraph(mlir::Operation *entryOp); -bool verifyAcyclic(const ComputeGraph &graph); +ComputeGraph buildComputeGraph(mlir::Operation* entryOp); +bool verifyAcyclic(const ComputeGraph& graph); -Weight getComputeInstanceWeight(const ComputeInstance &instance); -CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance); +Weight getComputeInstanceWeight(const ComputeInstance& instance); +CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp index 2d160d5..dcac91a 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstance.hpp @@ -11,11 +11,11 @@ namespace onnx_mlir { namespace spatial { struct ComputeInstance { - mlir::Operation *op = nullptr; + mlir::Operation* op = nullptr; uint32_t laneStart = 0; uint32_t laneCount = 1; - bool operator==(const ComputeInstance &other) const { + bool operator==(const ComputeInstance& other) const { return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount; } }; @@ -29,16 +29,15 @@ namespace llvm { template <> struct DenseMapInfo { static onnx_mlir::spatial::ComputeInstance getEmptyKey() { - return {DenseMapInfo::getEmptyKey(), UINT32_MAX, UINT32_MAX}; + return {DenseMapInfo::getEmptyKey(), UINT32_MAX, UINT32_MAX}; } static onnx_mlir::spatial::ComputeInstance getTombstoneKey() { - return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; + return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; } - static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) { + static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance& value) { return llvm::hash_combine(value.op, value.laneStart, value.laneCount); } - static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs, - const onnx_mlir::spatial::ComputeInstance &rhs) { + static bool isEqual(const onnx_mlir::spatial::ComputeInstance& lhs, const onnx_mlir::spatial::ComputeInstance& rhs) { return lhs == rhs; } }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp index 90af9d7..c257a1c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.hpp @@ -27,15 +27,15 @@ ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane); std::optional getProducerValueRef(mlir::Value value, - const ComputeInstance *consumerInstance = nullptr); + const ComputeInstance* consumerInstance = nullptr); std::optional getComputeProducerInstance(mlir::Value value, - const ComputeInstance *consumerInstance = nullptr); + const ComputeInstance* consumerInstance = nullptr); -llvm::SmallVector getComputeInstanceInputs(const ComputeInstance &instance); -llvm::SmallVector getComputeInstanceWeights(const ComputeInstance &instance); -llvm::SmallVector getComputeInstanceOutputValues(const ComputeInstance &instance); -llvm::SmallVector getComputeInstanceOutputTypes(const ComputeInstance &instance); -mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance); +llvm::SmallVector getComputeInstanceInputs(const ComputeInstance& instance); +llvm::SmallVector getComputeInstanceWeights(const ComputeInstance& instance); +llvm::SmallVector getComputeInstanceOutputValues(const ComputeInstance& instance); +llvm::SmallVector getComputeInstanceOutputTypes(const ComputeInstance& instance); +mlir::Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp index 6b3b7fd..cd7f177 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp @@ -10,8 +10,8 @@ #include #include -#include "DcpScheduler.hpp" #include "../DCPGraph/Graph.hpp" +#include "DcpScheduler.hpp" namespace onnx_mlir { namespace spatial { @@ -47,7 +47,7 @@ struct WindowScheduleResult { size_t maxMergeGroupSize = 0; }; -size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) { +size_t getSchedulingCpuBudget(const DcpScheduleOptions& options) { if (options.processorCount > 0) return options.processorCount; return std::numeric_limits::max(); @@ -72,7 +72,7 @@ std::vector aggregateEdges(llvm::ArrayRef edges) { for (auto [key, weight] : edgeWeights) aggregatedEdges.push_back( {static_cast(key.first), static_cast(key.second), static_cast(weight)}); - llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) { + llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) { if (std::get<0>(lhs) != std::get<0>(rhs)) return std::get<0>(lhs) < std::get<0>(rhs); return std::get<1>(lhs) < std::get<1>(rhs); @@ -80,7 +80,7 @@ std::vector aggregateEdges(llvm::ArrayRef edges) { return aggregatedEdges; } -VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) { +VirtualGraph buildInitialVirtualGraph(const ComputeGraph& graph) { VirtualGraph virtualGraph; virtualGraph.nodes.reserve(graph.nodes.size()); for (auto [index, node] : llvm::enumerate(graph.nodes)) { @@ -93,14 +93,14 @@ VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) { std::vector edges; edges.reserve(graph.edges.size()); - for (const ComputeGraphEdge &edge : graph.edges) + for (const ComputeGraphEdge& edge : graph.edges) edges.push_back( {static_cast(edge.source), static_cast(edge.target), static_cast(edge.transferCost)}); virtualGraph.edges = aggregateEdges(edges); return virtualGraph; } -TimingInfo computeTiming(const VirtualGraph &graph) { +TimingInfo computeTiming(const VirtualGraph& graph) { TimingInfo timing; size_t nodeCount = graph.nodes.size(); timing.aest.assign(nodeCount, 0); @@ -122,7 +122,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) { } auto getVirtualNodeOrderKey = [&](size_t nodeIndex) { - const VirtualNode &node = graph.nodes[nodeIndex]; + const VirtualNode& node = graph.nodes[nodeIndex]; if (!node.originalNodeIndices.empty()) return node.originalNodeIndices.front(); return nodeIndex; @@ -181,7 +181,7 @@ TimingInfo computeTiming(const VirtualGraph &graph) { return timing; } -std::vector> buildUndirectedAdjacency(const VirtualGraph &graph) { +std::vector> buildUndirectedAdjacency(const VirtualGraph& graph) { std::vector> adjacency(graph.nodes.size()); for (auto [start, end, weight] : graph.edges) { (void) weight; @@ -191,14 +191,14 @@ std::vector> buildUndirectedAdjacency(const VirtualGraph &gr adjacency[startIndex].push_back(endIndex); adjacency[endIndex].push_back(startIndex); } - for (auto &neighbours : adjacency) { + for (auto& neighbours : adjacency) { llvm::sort(neighbours); neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end()); } return adjacency; } -std::vector selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) { +std::vector selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) { std::vector ranked(timing.aest.size()); std::iota(ranked.begin(), ranked.end(), 0); auto isHigherPriority = [&](size_t lhs, size_t rhs) { @@ -240,7 +240,7 @@ std::vector selectCriticalWindow(const VirtualGraph &graph, const Timing auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); }; std::priority_queue, decltype(frontierCompare)> frontier(frontierCompare); - auto addToWindow = [&](size_t node, const std::vector &eligible) { + auto addToWindow = [&](size_t node, const std::vector& eligible) { if (inWindow[node]) return; inWindow[node] = true; @@ -288,7 +288,7 @@ std::vector selectCriticalWindow(const VirtualGraph &graph, const Timing return selected; } -std::vector buildWindowEdges(const VirtualGraph &graph, const std::vector &nodeToWindowIndex) { +std::vector buildWindowEdges(const VirtualGraph& graph, const std::vector& nodeToWindowIndex) { std::vector windowEdges; windowEdges.reserve(graph.edges.size()); for (auto [start, end, weight] : graph.edges) { @@ -301,10 +301,10 @@ std::vector buildWindowEdges(const VirtualGraph &graph, const std:: return aggregateEdges(windowEdges); } -WindowScheduleResult scheduleWindow(const VirtualGraph &graph, +WindowScheduleResult scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef selectedNodes, - const DcpScheduleOptions &options, - mlir::MLIRContext *context) { + const DcpScheduleOptions& options, + mlir::MLIRContext* context) { std::vector windowWeights; std::vector windowCrossbarUsage; std::vector windowNodeOrderKeys; @@ -338,17 +338,17 @@ WindowScheduleResult scheduleWindow(const VirtualGraph &graph, result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size()); std::vector mergeGroup; mergeGroup.reserve(scheduledTasks.size()); - for (const auto &task : scheduledTasks) + for (const auto& task : scheduledTasks) mergeGroup.push_back(selectedNodes[task.nodeIndex]); result.mergeGroups.push_back(std::move(mergeGroup)); } return result; } -bool coarsenGraph(const VirtualGraph &graph, +bool coarsenGraph(const VirtualGraph& graph, llvm::ArrayRef> mergeGroups, - VirtualGraph &coarsenedGraph, - std::vector &oldToNewNode) { + VirtualGraph& coarsenedGraph, + std::vector& oldToNewNode) { TimingInfo timing = computeTiming(graph); std::vector topologicalRank(graph.nodes.size()); std::iota(topologicalRank.begin(), topologicalRank.end(), 0); @@ -358,7 +358,7 @@ bool coarsenGraph(const VirtualGraph &graph, std::vector> orderedMergeGroups; orderedMergeGroups.reserve(mergeGroups.size()); - for (const auto &mergeGroup : mergeGroups) { + for (const auto& mergeGroup : mergeGroups) { orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end()); std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) { if (topologicalRank[lhs] != topologicalRank[rhs]) @@ -395,7 +395,7 @@ bool coarsenGraph(const VirtualGraph &graph, continue; } - auto &newNodeIndex = mergeGroupToNewNode[static_cast(mergeGroupIndex)]; + auto& newNodeIndex = mergeGroupToNewNode[static_cast(mergeGroupIndex)]; if (newNodeIndex.has_value()) { oldToNewNode[nodeIndex] = *newNodeIndex; continue; @@ -403,8 +403,9 @@ bool coarsenGraph(const VirtualGraph &graph, VirtualNode mergedNode; for (size_t memberIndex : orderedMergeGroups[static_cast(mergeGroupIndex)]) { - const VirtualNode &memberNode = graph.nodes[memberIndex]; - mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end()); + const VirtualNode& memberNode = graph.nodes[memberIndex]; + mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), + memberNode.originalNodeIndices.end()); mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight); mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage); } @@ -437,7 +438,7 @@ bool coarsenGraph(const VirtualGraph &graph, return true; } -size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) { +size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions& options) { size_t windowSize = std::min(options.criticalWindowSize, nodeCount); CPU maxCpuCount = std::max(1, static_cast(getSchedulingCpuBudget(options))); if (nodeCount > static_cast(maxCpuCount)) @@ -445,7 +446,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &op return windowSize; } -void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) { +void assignFeasibleAest(const ComputeGraph& graph, MergeScheduleResult& result) { llvm::DenseMap nodeIndexByInstance; nodeIndexByInstance.reserve(graph.nodes.size()); for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes)) @@ -458,7 +459,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) std::vector> scheduledChildren(graph.nodes.size()); std::vector incomingEdgeCount(graph.nodes.size(), 0); - for (const ComputeGraphEdge &edge : graph.edges) { + for (const ComputeGraphEdge& edge : graph.edges) { const ComputeInstance sourceInstance = graph.nodes[edge.source].instance; const ComputeInstance targetInstance = graph.nodes[edge.target].instance; const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance); @@ -473,15 +474,15 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) } llvm::DenseMap>> tasksByCpu; - for (const ComputeGraphNode &node : graph.nodes) { + for (const ComputeGraphNode& node : graph.nodes) { size_t cpu = result.computeToCpuMap.lookup(node.instance); size_t slot = result.computeToCpuSlotMap.lookup(node.instance); tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)}); } - for (auto &entry : tasksByCpu) { - auto &scheduledTasks = entry.second; - llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) { + for (auto& entry : tasksByCpu) { + auto& scheduledTasks = entry.second; + llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) { if (lhs.first != rhs.first) return lhs.first < rhs.first; return lhs.second < rhs.second; @@ -512,7 +513,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) readyNodes.pop(); processedNodeCount++; - for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) { + for (const ScheduledEdge& edge : scheduledChildren[sourceIndex]) { startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay)); assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow"); incomingEdgeCount[edge.target]--; @@ -528,7 +529,7 @@ void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) result.computeToAestMap[node.instance] = startTimes[nodeIndex]; } -MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) { +MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph& graph, const ComputeGraph& originalGraph) { MergeScheduleResult result; TimingInfo timing = computeTiming(graph); @@ -542,7 +543,7 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const std::vector originalNodeToCpu(originalGraph.nodes.size(), 0); for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) { - const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex]; + const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex]; for (size_t originalIndex : virtualNode.originalNodeIndices) originalNodeToCpu[originalIndex] = cpu; } @@ -556,17 +557,17 @@ MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++; result.cpuToLastComputeMap[cpu] = node.instance; } - for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap) + for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap) result.isLastComputeOfCpu.insert(lastCompute); assignFeasibleAest(originalGraph, result); return result; } -MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) { +MergeScheduleResult buildResultFromScheduledGraph(GraphDCP& graphDCP, const ComputeGraph& graph) { MergeScheduleResult result; result.dominanceOrderCompute.reserve(graph.nodes.size()); - for (const ComputeGraphNode &node : graph.nodes) + for (const ComputeGraphNode& node : graph.nodes) result.dominanceOrderCompute.push_back(node.instance); for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) { @@ -589,7 +590,8 @@ MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const Comp return result; } -MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) { +MergeScheduleResult +runLegacyDcp(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) { llvm::SmallVector nodeWeights; llvm::SmallVector nodeCrossbarUsage; llvm::SmallVector nodeOrderKeys; @@ -599,12 +601,12 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt nodeOrderKeys.reserve(graph.nodes.size()); edges.reserve(graph.edges.size()); - for (const ComputeGraphNode &node : graph.nodes) { + for (const ComputeGraphNode& node : graph.nodes) { nodeWeights.push_back(node.weight); nodeCrossbarUsage.push_back(node.crossbarUsage); nodeOrderKeys.push_back(static_cast(node.originalOrder)); } - for (const ComputeGraphEdge &edge : graph.edges) { + for (const ComputeGraphEdge& edge : graph.edges) { edges.push_back( {static_cast(edge.source), static_cast(edge.target), static_cast(edge.transferCost)}); } @@ -617,11 +619,11 @@ MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOpt return buildResultFromScheduledGraph(graphDCP, graph); } -bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) { +bool needsExactScheduledBatches(const ComputeGraph& graph, const DcpScheduleOptions& options) { if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount) return false; size_t schedulingCpuBudget = getSchedulingCpuBudget(options); - return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) { + return llvm::any_of(graph.nodes, [&](const ComputeGraphNode& node) { auto batch = dyn_cast(node.instance.op); return batch && static_cast(batch.getLaneCount()) > schedulingCpuBudget; }); @@ -630,7 +632,7 @@ bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOpti } // namespace MergeScheduleResult -runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) { +runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context) { if (needsExactScheduledBatches(graph, options)) return runLegacyDcp(graph, options, context); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp index 99ebf82..e050dad 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/DcpScheduler.hpp @@ -15,7 +15,7 @@ struct DcpScheduleOptions { }; MergeScheduleResult -runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context); +runDcpScheduler(const ComputeGraph& graph, const DcpScheduleOptions& options, mlir::MLIRContext* context); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp index 60c6b1c..d1efd67 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp @@ -1,13 +1,13 @@ -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include #include -#include "ComputeGraph.hpp" #include "../DCPGraph/DCPAnalysis.hpp" +#include "ComputeGraph.hpp" #include "DcpScheduler.hpp" #include "MergeSchedulingAnalysis.hpp" #include "PeftScheduler.hpp" @@ -20,15 +20,13 @@ namespace { MergeSchedulerKind getSchedulerKind() { switch (pimMergeScheduler.getValue()) { - case MergeSchedulerPeft: - return MergeSchedulerKind::Peft; - case MergeSchedulerDcp: - return MergeSchedulerKind::Dcp; + case MergeSchedulerPeft: return MergeSchedulerKind::Peft; + case MergeSchedulerDcp: return MergeSchedulerKind::Dcp; } llvm_unreachable("unknown merge scheduler kind"); } -void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) { +void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, CrossbarUsage crossbarCapacity) { llvm::DenseMap>> tasksByCpu; tasksByCpu.reserve(result.cpuToLastComputeMap.size()); @@ -45,9 +43,9 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result {result.computeToCpuSlotMap.lookup(instance), nodeIndex}); } - for (auto &entry : tasksByCpu) { - auto &scheduledTasks = entry.second; - llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) { + for (auto& entry : tasksByCpu) { + auto& scheduledTasks = entry.second; + llvm::sort(scheduledTasks, [](const auto& lhs, const auto& rhs) { if (lhs.first != rhs.first) return lhs.first < rhs.first; return lhs.second < rhs.second; @@ -70,7 +68,7 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result llvm::report_fatal_error("merge scheduling: missing last-compute marker"); } - for (const ComputeGraphEdge &edge : graph.edges) { + for (const ComputeGraphEdge& edge : graph.edges) { const ComputeInstance source = graph.nodes[edge.source].instance; const ComputeInstance target = graph.nodes[edge.target].instance; const size_t sourceCpu = result.computeToCpuMap.lookup(source); @@ -97,8 +95,8 @@ void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result } // namespace -MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op) - : entryOp(op) { +MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation* op) +: entryOp(op) { result = run(); } @@ -115,20 +113,17 @@ MergeScheduleResult MergeSchedulingAnalysis::run() { MergeScheduleResult schedule; if (options.kind == MergeSchedulerKind::Peft) { - schedule = runPeftScheduler( - graph, - PeftScheduleOptions {options.processorCount, static_cast(crossbarCountInCore.getValue()), - entryOp->getContext()}); + schedule = runPeftScheduler(graph, + PeftScheduleOptions {options.processorCount, + static_cast(crossbarCountInCore.getValue()), + entryOp->getContext()}); } else { - schedule = runDcpScheduler( - graph, - DcpScheduleOptions { - options.processorCount, - dcpCriticalWindowSize.getValue(), - options.allowDcpFallbackForAutoCoreCount - }, - entryOp->getContext()); + schedule = runDcpScheduler(graph, + DcpScheduleOptions {options.processorCount, + dcpCriticalWindowSize.getValue(), + options.allowDcpFallbackForAutoCoreCount}, + entryOp->getContext()); } verifySchedule(graph, schedule, static_cast(crossbarCountInCore.getValue())); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.hpp index a36f679..53b7fe5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.hpp @@ -22,11 +22,11 @@ struct MergeSchedulingOptions { class MergeSchedulingAnalysis { public: - explicit MergeSchedulingAnalysis(mlir::Operation *op); - MergeScheduleResult &getResult() { return result; } + explicit MergeSchedulingAnalysis(mlir::Operation* op); + MergeScheduleResult& getResult() { return result; } private: - mlir::Operation *entryOp = nullptr; + mlir::Operation* entryOp = nullptr; MergeScheduleResult result; MergeScheduleResult run(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.hpp index 476c754..b9bd170 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.hpp @@ -11,10 +11,10 @@ namespace spatial { struct PeftScheduleOptions { size_t processorCount = 0; CrossbarUsage crossbarCapacity = 0; - mlir::MLIRContext *context = nullptr; + mlir::MLIRContext* context = nullptr; }; -MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options); +MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 057cfe7..aa9cad4 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -8,8 +8,8 @@ #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" +#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -211,8 +211,9 @@ struct VerificationPass : PassWrapper> if (auto coreBatchOp = dyn_cast(&op)) { (void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics); for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) - (void) withScalarCoreFromBatchLane( - coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { return verifyCoreOperands(scalarCore, diagnostics); }); + (void) withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp scalarCore) { + return verifyCoreOperands(scalarCore, diagnostics); + }); continue; }