diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 0f72d1f..3a9d382 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -4,6 +4,7 @@ #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Support/raw_os_ostream.h" #include @@ -96,6 +97,53 @@ void markWeightAlways(Operation* op) { op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext())); } +namespace { + +template +bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { + bool found = false; + parentOp.walk([&](Operation* op) { + if (auto mvmOp = dyn_cast(op)) + found |= mvmOp.getWeightIndex() == weightIndex; + else if (auto vmmOp = dyn_cast(op)) + found |= vmmOp.getWeightIndex() == weightIndex; + }); + return found; +} + +template +void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref callback) { + auto weights = parentOp.getWeights(); + llvm::SmallSet visited; + auto walkWeightIndex = [&](unsigned weightIndex) { + if (weightIndex < weights.size() && visited.insert(weightIndex).second) + callback(parentOp->getOpOperand(weightIndex)); + }; + + parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); + parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); +} + +} // namespace + +bool isSpatialMvmVmmWeightUse(OpOperand& use) { + Operation* user = use.getOwner(); + unsigned operandIndex = use.getOperandNumber(); + + auto computeOp = dyn_cast(user); + if (!computeOp || operandIndex >= computeOp.getWeights().size()) + return false; + + return hasMvmVmmWeightUse(computeOp, operandIndex); +} + +void walkPimMvmVmmWeightUses(Operation* root, function_ref callback) { + assert(root && "expected valid root op"); + root->walk([&](pim::PimCoreOp coreOp) { + walkMvmVmmWeightUses(coreOp, callback); + }); +} + memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { if (!moduleOp || !getGlobalOp) return {}; diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 99cec4d..7b42d35 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -7,6 +7,7 @@ #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -40,6 +41,10 @@ bool hasWeightAlways(mlir::Operation* op); void markWeightAlways(mlir::Operation* op); +bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use); + +void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); + mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); llvm::FailureOr diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 7909b82..71b7619 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -392,7 +392,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { funcOp.walk([&](arith::ConstantOp constantOp) { bool isAlwaysWeight = - llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); }); + !constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool { + return isSpatialMvmVmmWeightUse(use); + }); if (isAlwaysWeight) markWeightAlways(constantOp); }); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 69da36a..0cd8482 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -94,10 +94,8 @@ void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { funcOp.walk([&](PimCoreOp coreOp) { - auto annotateWeight = [&](unsigned weightIndex) { - if (weightIndex >= coreOp.getWeights().size()) - return; - Value weight = coreOp.getWeights()[weightIndex]; + walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) { + Value weight = weightUse.get(); auto getGlobalOp = weight.getDefiningOp(); if (!getGlobalOp) return; @@ -105,10 +103,7 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO assert("Weights must be constants" && globalMemrefOp.getConstant()); markWeightAlways(getGlobalOp); markWeightAlways(globalMemrefOp); - }; - - coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); }); - coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); }); + }); }); } diff --git a/test/PIM/CMakeLists.txt b/test/PIM/CMakeLists.txt index 9fb2355..b7735ec 100644 --- a/test/PIM/CMakeLists.txt +++ b/test/PIM/CMakeLists.txt @@ -1,5 +1,3 @@ -# SPDX-License-Identifier: Apache-2.0 - add_custom_target(pim-unittest) set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")