From addfc8a86edbbda7cea1796e3ddc2b36365079b9 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 25 May 2026 21:22:08 +0200 Subject: [PATCH] remove other dead logic --- README.md | 2 +- src/PIM/Common/IR/WeightUtils.cpp | 15 +- src/PIM/Compiler/PimCompilerUtils.cpp | 4 - src/PIM/Dialect/Pim/Pim.td | 10 - src/PIM/Dialect/Pim/PimOpsAsm.cpp | 34 ---- src/PIM/Dialect/Spatial/CMakeLists.txt | 1 - src/PIM/Dialect/Spatial/Channels.cpp | 178 ------------------ src/PIM/Dialect/Spatial/Channels.hpp | 43 ----- src/PIM/Dialect/Spatial/Spatial.td | 35 ---- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 77 -------- .../MergeComputeNodesPass.cpp | 11 +- src/PIM/Pass/CMakeLists.txt | 1 - src/PIM/Pass/CountInstructionPass.cpp | 59 ------ src/PIM/Pass/PIMPasses.h | 2 - 14 files changed, 10 insertions(+), 462 deletions(-) delete mode 100644 src/PIM/Dialect/Spatial/Channels.cpp delete mode 100644 src/PIM/Dialect/Spatial/Channels.hpp delete mode 100644 src/PIM/Pass/CountInstructionPass.cpp diff --git a/README.md b/README.md index b696389..e569cf6 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ Supporting pieces: core count, DCP window, experimental conv impl, concat error handling, …) and `PimCodeGen` entry points. - `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`). -- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`) +- `src/PIM/Pass` — auxiliary passes (`MessagePass`) and the `PIMPasses.h` registry used by `PimAccelerator`. - `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers dialects, passes, and plugs Raptor into the ONNX-MLIR driver. diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 9ba9f21..ab2b113 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -19,23 +19,21 @@ void markWeightAlways(mlir::Operation* op) { namespace { -template -bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { +template +bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { auto weightArg = parentOp.getWeightArgument(weightIndex); if (!weightArg) return false; bool found = false; parentOp.walk([&](mlir::Operation* op) { - if (auto mvmOp = mlir::dyn_cast(op)) - found |= mvmOp.getWeight() == *weightArg; - else if (auto vmmOp = mlir::dyn_cast(op)) + if (auto vmmOp = mlir::dyn_cast(op)) found |= vmmOp.getWeight() == *weightArg; }); return found; } -template -void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) { +template +void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref callback) { auto weights = parentOp.getWeights(); llvm::SmallSet visited; auto walkWeight = [&](mlir::Value weight) { @@ -49,7 +47,6 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref= computeOp.getWeights().size()) return false; - return hasMvmVmmWeightUse(computeOp, operandIndex); + return hasVmmWeightUse(computeOp, operandIndex); } bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { diff --git a/src/PIM/Compiler/PimCompilerUtils.cpp b/src/PIM/Compiler/PimCompilerUtils.cpp index bc72bd6..d2e2ea0 100644 --- a/src/PIM/Compiler/PimCompilerUtils.cpp +++ b/src/PIM/Compiler/PimCompilerUtils.cpp @@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef& module, if (pimEmissionTarget >= EmitSpatial) { pm.addPass(createONNXToSpatialPass()); pm.addPass(createMergeComputeNodesPass()); - // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Onnx lowered to Spatial")); } if (pimEmissionTarget >= EmitPim) { pm.addPass(createSpatialToPimPass()); - // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Spatial lowered to Pim")); } if (pimEmissionTarget >= EmitPimBufferized) { pm.addPass(createPimBufferizationPass()); pm.addPass(createPimStaticMemoryCoalescingPass()); - // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Pim bufferized")); } @@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef& module, pm.addPass(createPimVerificationPass()); pm.addPass(createMessagePass("Pim verified")); pm.addPass(createEmitPimCodePass()); - // pm.addPass(createCountInstructionPass()); pm.addPass(createMessagePass("Pim code emitted")); } } diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 343ca9e..1d8aa34 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -74,16 +74,6 @@ def PimHaltOp : PimOp<"halt", [Terminator]> { }]; } -def PimYieldOp : PimOp<"yield", [Terminator]> { - let summary = "Yield results from a Pim region"; - - let arguments = (ins - Variadic:$outputs - ); - - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp index eedd8ef..c8779ec 100644 --- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -227,40 +227,6 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) { return parser.parseRegion(*body, regionArgs); } -void PimYieldOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getOutputs()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printCompressedTypeSequence(printer, getOutputs().getTypes()); -} - -ParseResult PimYieldOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputs; - SmallVector outputTypes; - - OpAsmParser::UnresolvedOperand firstOutput; - OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput); - if (firstOutputResult.has_value()) { - if (failed(*firstOutputResult)) - return failure(); - if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs)) - return failure(); - while (succeeded(parser.parseOptionalComma())) - if (parseOneCompressedOperandEntry(parser, outputs)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) - return failure(); - - if (outputs.size() != outputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match"); - - return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); -} - void PimConcatOp::print(OpAsmPrinter& printer) { printer << " axis " << getAxis() << " "; printCompressedValueSequence(printer, getInputs()); diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index f843879..e5381ee 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -2,7 +2,6 @@ add_onnx_mlir_dialect(Spatial spat) add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps - Channels.cpp SpatialOps.cpp SpatialOpsAsm.cpp SpatialOpsVerify.cpp diff --git a/src/PIM/Dialect/Spatial/Channels.cpp b/src/PIM/Dialect/Spatial/Channels.cpp deleted file mode 100644 index 7facf82..0000000 --- a/src/PIM/Dialect/Spatial/Channels.cpp +++ /dev/null @@ -1,178 +0,0 @@ -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Matchers.h" - -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" - -using namespace mlir; - -namespace onnx_mlir::spatial { - -namespace { - -static FailureOr getConstantI64(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return constantValue.getSExtValue(); -} - -static FailureOr getConstantI32(Value value) { - APInt constantValue; - if (!matchPattern(value, m_ConstantInt(&constantValue))) - return failure(); - return static_cast(constantValue.getSExtValue()); -} - -static FailureOr getChannelId(SpatChannelSendOp sendOp) { - return getConstantI64(sendOp.getChannelId()); -} - -static FailureOr getChannelId(SpatChannelReceiveOp receiveOp) { - return getConstantI64(receiveOp.getChannelId()); -} - -static FailureOr getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); } - -static FailureOr getSourceCoreId(SpatChannelReceiveOp receiveOp) { - return getConstantI32(receiveOp.getSourceCoreId()); -} - -static FailureOr getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); } - -static FailureOr getTargetCoreId(SpatChannelReceiveOp receiveOp) { - return getConstantI32(receiveOp.getTargetCoreId()); -} - -static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) { - if (!endpoints.send || !endpoints.receive) - return failure(); - - FailureOr sendSourceCoreId = getSourceCoreId(endpoints.send); - FailureOr receiveSourceCoreId = getSourceCoreId(endpoints.receive); - if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) { - endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands"); - return failure(); - } - if (*sendSourceCoreId != *receiveSourceCoreId) { - endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive"); - return failure(); - } - - FailureOr sendTargetCoreId = getTargetCoreId(endpoints.send); - FailureOr receiveTargetCoreId = getTargetCoreId(endpoints.receive); - if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) { - endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands"); - return failure(); - } - if (*sendTargetCoreId != *receiveTargetCoreId) { - endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive"); - return failure(); - } - if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) { - endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type"); - return failure(); - } - - return success(); -} - -} // namespace - -Channels::Channels(func::FuncOp funcOp) { - if (!funcOp) - return; - - funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); }); - funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); }); -} - -Channels::ChannelId Channels::allocate() { return nextChannelId++; } - -void Channels::insertSend(SpatChannelSendOp sendOp) { - FailureOr channelId = getChannelId(sendOp); - if (failed(channelId)) - return; - nextChannelId = std::max(nextChannelId, *channelId + 1); - endpoints[*channelId].send = sendOp; -} - -void Channels::insertReceive(SpatChannelReceiveOp receiveOp) { - FailureOr channelId = getChannelId(receiveOp); - if (failed(channelId)) - return; - nextChannelId = std::max(nextChannelId, *channelId + 1); - endpoints[*channelId].receive = receiveOp; -} - -void Channels::eraseSend(SpatChannelSendOp sendOp) { - FailureOr channelId = getChannelId(sendOp); - if (failed(channelId)) - return; - auto it = endpoints.find(*channelId); - if (it == endpoints.end()) - return; - it->second.send = {}; - if (!it->second.receive) - endpoints.erase(it); -} - -void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) { - FailureOr channelId = getChannelId(receiveOp); - if (failed(channelId)) - return; - auto it = endpoints.find(*channelId); - if (it == endpoints.end()) - return; - it->second.receive = {}; - if (!it->second.send) - endpoints.erase(it); -} - -FailureOr Channels::lookup(ChannelId id) const { - auto it = endpoints.find(id); - if (it == endpoints.end()) - return failure(); - return it->second; -} - -FailureOr Channels::getReceiveFor(SpatChannelSendOp sendOp) const { - FailureOr channelId = getChannelId(sendOp); - if (failed(channelId)) - return failure(); - auto endpointsOr = lookup(*channelId); - if (failed(endpointsOr) || !endpointsOr->receive) - return failure(); - return endpointsOr->receive; -} - -FailureOr Channels::getSendFor(SpatChannelReceiveOp receiveOp) const { - FailureOr channelId = getChannelId(receiveOp); - if (failed(channelId)) - return failure(); - auto endpointsOr = lookup(*channelId); - if (failed(endpointsOr) || !endpointsOr->send) - return failure(); - return endpointsOr->send; -} - -LogicalResult Channels::verify() const { - for (const auto& [channelId, pair] : endpoints) { - if (!pair.send || !pair.receive) { - if (pair.send) { - auto sendOp = pair.send; - sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive"; - } - else if (pair.receive) { - auto receiveOp = pair.receive; - receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send"; - } - return failure(); - } - if (failed(verifyEndpointPair(pair))) - return failure(); - } - return success(); -} - -} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Channels.hpp b/src/PIM/Dialect/Spatial/Channels.hpp deleted file mode 100644 index 5f99569..0000000 --- a/src/PIM/Dialect/Spatial/Channels.hpp +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Support/LogicalResult.h" - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringRef.h" - -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -namespace onnx_mlir::spatial { - -struct ChannelEndpoints { - SpatChannelSendOp send; - SpatChannelReceiveOp receive; -}; - -class Channels { -public: - using ChannelId = int64_t; - - explicit Channels(mlir::func::FuncOp funcOp); - - ChannelId allocate(); - - void insertSend(SpatChannelSendOp sendOp); - void insertReceive(SpatChannelReceiveOp receiveOp); - void eraseSend(SpatChannelSendOp sendOp); - void eraseReceive(SpatChannelReceiveOp receiveOp); - - llvm::FailureOr lookup(ChannelId id) const; - llvm::FailureOr getReceiveFor(SpatChannelSendOp sendOp) const; - llvm::FailureOr getSendFor(SpatChannelReceiveOp receiveOp) const; - - mlir::LogicalResult verify() const; - -private: - ChannelId nextChannelId = 0; - llvm::DenseMap endpoints; -}; - -} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 38d0a9f..3489ddf 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -217,25 +217,6 @@ def SpatVMMOp : SpatOp<"wvmm", []> { }]; } -def SpatMVMOp : SpatOp<"Wmvm", []> { - let summary = "Matrix-vector multiplication within a weighted compute operation"; - - let arguments = (ins - SpatTensor:$weight, - SpatTensor:$input - ); - - let results = (outs - SpatTensor:$output - ); - - let hasVerifier = 1; - - let assemblyFormat = [{ - `[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output) - }]; -} - def SpatVAddOp : SpatOp<"vadd", []> { let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1"; @@ -272,22 +253,6 @@ def SpatVMulOp : SpatOp<"vmul", []> { }]; } -def SpatSumOp : SpatOp<"sum", []> { - let summary = "Reduce all elements of the input tensor to a single scalar wrapped in a tensor"; - - let arguments = (ins - SpatTensor:$input - ); - - let results = (outs - SpatTensor:$output - ); - - let assemblyFormat = [{ - `(` $input `)` attr-dict `:` type($input) `->` type($output) - }]; -} - def SpatVAvgOp : SpatOp<"vavg", []> { let summary = "Average all elements of the input tensor to a single scalar wrapped in a tensor"; diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 9b315fa..364832d 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -26,68 +26,6 @@ namespace spatial { namespace { -inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter, - ArrayRef& matrixShape, - ArrayRef& vectorShape, - ArrayRef& outputShape) { - if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) - return emitter->emitError("matrix, vector and output must have rank 2"); - - int64_t N = matrixShape[0]; - int64_t M = matrixShape[1]; - if (N <= 0 || M <= 0) - return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0"); - - int64_t vectorM = vectorShape[0]; - int64_t vector1 = vectorShape[1]; - if (vectorM != M || vector1 != 1) - return emitter->emitError("vector shape must be (M, 1)"); - - int64_t outputN = outputShape[0]; - int64_t output1 = outputShape[1]; - if (outputN != N || output1 != 1) - return emitter->emitError("output shape must be (N, 1)"); - - return success(); -} - -inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter, - ArrayRef& matrixShape, - ArrayRef& vectorShape, - ArrayRef& outputShape) { - if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4) - return emitter->emitError("matrix, vector and output must have rank 4"); - - int64_t N = matrixShape[0]; - int64_t M = matrixShape[1]; - int64_t matrix1First = matrixShape[2]; - int64_t matrix1Second = matrixShape[3]; - if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1) - return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0"); - - int64_t vector1First = vectorShape[0]; - int64_t vectorM = vectorShape[1]; - int64_t vector1Second = vectorShape[2]; - int64_t vector1Third = vectorShape[3]; - if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) { - if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) { - // This is ok, it was caused by the simplification of the concat error. - } - else { - return emitter->emitError("vector shape must be (1, M, 1, 1)"); - } - } - - int64_t output1First = outputShape[0]; - int64_t outputN = outputShape[1]; - int64_t output1Second = outputShape[2]; - int64_t output1Third = outputShape[3]; - if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1) - return emitter->emitError("output shape must be (1, N, 1, 1)"); - - return success(); -} - static FailureOr> getWeightShapeForWeightedOp(Value weight) { auto shapedType = dyn_cast(weight.getType()); if (!shapedType) @@ -287,21 +225,6 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) { } // namespace -LogicalResult SpatMVMOp::verify() { - auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight()); - if (failed(matrixShapeOpt)) - return emitError("weight must be a shaped value"); - auto matrixShape = *matrixShapeOpt; - auto vectorShape = getInput().getType().getShape(); - auto outputShape = getOutput().getType().getShape(); - - if (matrixShape.size() == 2) - return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape); - if (matrixShape.size() == 4) - return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape); - return emitError("matrix rank must be 2 or 4"); -} - LogicalResult SpatVMMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight()); if (failed(matrixShapeOpt)) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 4552762..ec97840 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -136,7 +136,6 @@ static std::optional getComputeCoreId(SpatCompute compute) { struct ComputeMotifInfo { uint64_t instructionCount = 0; - uint64_t weightedMvmCount = 0; uint64_t weightedVmmCount = 0; }; @@ -285,8 +284,6 @@ void emitMotifProfile(func::FuncOp funcOp) { ComputeMotifInfo& info = computeInfos[index]; for (Operation& op : compute.getBody().front()) { info.instructionCount++; - if (isa(&op)) - info.weightedMvmCount++; if (isa(&op)) info.weightedVmmCount++; } @@ -400,7 +397,7 @@ void emitMotifProfile(func::FuncOp funcOp) { wideWeightedVmmLevels256 += count >= 256; } - using ShapeKey = std::tuple; + using ShapeKey = std::tuple; SmallVector weightedVmmShapeKeys; for (auto [index, compute] : llvm::enumerate(computes)) { const ComputeMotifInfo& info = computeInfos[index]; @@ -408,7 +405,6 @@ void emitMotifProfile(func::FuncOp funcOp) { continue; weightedVmmShapeKeys.push_back({info.instructionCount, info.weightedVmmCount, - info.weightedMvmCount, static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size()), static_cast(parents[index].size()), @@ -461,14 +457,13 @@ void emitMotifProfile(func::FuncOp funcOp) { for (size_t rank = 0, end = std::min(weightedVmmShapeCounts.size(), 5); rank < end; ++rank) { auto [count, shape] = weightedVmmShapeCounts[rank]; - auto [insts, vmmOps, mvmOps, weights, inputs, fanIn, fanOut] = shape; + auto [insts, vmmOps, weights, inputs, fanIn, fanOut] = shape; llvm::errs() << llvm::formatv("[DCP-MOTIF] wvmmShape rank={0} count={1} insts={2} vmmOps={3} " - "mvmOps={4} weights={5} inputs={6} fanIn={7} fanOut={8}\n", + "weights={4} inputs={5} fanIn={6} fanOut={7}\n", rank, count, insts, vmmOps, - mvmOps, weights, inputs, fanIn, diff --git a/src/PIM/Pass/CMakeLists.txt b/src/PIM/Pass/CMakeLists.txt index 95edf64..07080c8 100644 --- a/src/PIM/Pass/CMakeLists.txt +++ b/src/PIM/Pass/CMakeLists.txt @@ -1,5 +1,4 @@ add_pim_library(OMPimPasses - CountInstructionPass.cpp MessagePass.cpp PimCodegen/HostConstantFolding/Common.cpp PimCodegen/HostConstantFolding/Patterns/Constant.cpp diff --git a/src/PIM/Pass/CountInstructionPass.cpp b/src/PIM/Pass/CountInstructionPass.cpp deleted file mode 100644 index 2a8ad9e..0000000 --- a/src/PIM/Pass/CountInstructionPass.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Pass/Pass.h" - -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -namespace { - -struct CountInstructionPass : public PassWrapper> { - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass) - - StringRef getArgument() const override { return "count-instruction-pass"; } - - StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; } - - // Make sure that we have a valid default constructor and copy - // constructor to make sure that the options are initialized properly. - CountInstructionPass() {} - CountInstructionPass(const CountInstructionPass& pass) - : PassWrapper>() {} - void runOnOperation() final { - ModuleOp module = getOperation(); - - func::FuncOp func = *module.getOps().begin(); - - unsigned totalInstructionCount = 0; - - unsigned computeId = 0; - for (auto computeOp : func.getOps()) { - unsigned instructionCount = 0; - instructionCount += computeOp.getBody().front().getOperations().size(); - llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n"; - totalInstructionCount += instructionCount; - computeId++; - } - - unsigned coreId = 0; - for (auto coreOp : func.getOps()) { - unsigned instructionCount = 0; - instructionCount += coreOp.getBody().front().getOperations().size(); - llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n"; - totalInstructionCount += instructionCount; - coreId++; - } - - llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n"; - } -}; - -} // namespace - -std::unique_ptr createCountInstructionPass() { return std::make_unique(); } - -} // namespace onnx_mlir diff --git a/src/PIM/Pass/PIMPasses.h b/src/PIM/Pass/PIMPasses.h index 9f24bd1..86a1e1e 100644 --- a/src/PIM/Pass/PIMPasses.h +++ b/src/PIM/Pass/PIMPasses.h @@ -29,6 +29,4 @@ std::unique_ptr createEmitPimCodePass(); std::unique_ptr createMessagePass(std::string message); -std::unique_ptr createCountInstructionPass(); - } // namespace onnx_mlir