fix weightAlways attribute in spatial

This commit is contained in:
NiccoloN
2026-04-23 10:04:47 +02:00
parent 412ca957f6
commit 89b3501aa8
5 changed files with 59 additions and 11 deletions

View File

@@ -4,6 +4,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <filesystem> #include <filesystem>
@@ -96,6 +97,53 @@ void markWeightAlways(Operation* op) {
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext())); op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
} }
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](Operation* op) {
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
}
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
});
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp) if (!moduleOp || !getGlobalOp)
return {}; return {};

View File

@@ -7,6 +7,7 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
@@ -40,6 +41,10 @@ bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op); void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*> llvm::FailureOr<mlir::Operation*>

View File

@@ -392,7 +392,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) { funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight = bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); }); !constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
return isSpatialMvmVmmWeightUse(use);
});
if (isAlwaysWeight) if (isAlwaysWeight)
markWeightAlways(constantOp); markWeightAlways(constantOp);
}); });

View File

@@ -94,10 +94,8 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
funcOp.walk([&](PimCoreOp coreOp) { funcOp.walk([&](PimCoreOp coreOp) {
auto annotateWeight = [&](unsigned weightIndex) { walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
if (weightIndex >= coreOp.getWeights().size()) Value weight = weightUse.get();
return;
Value weight = coreOp.getWeights()[weightIndex];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>(); auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) if (!getGlobalOp)
return; return;
@@ -105,10 +103,7 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
markWeightAlways(getGlobalOp); markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp); markWeightAlways(globalMemrefOp);
}; });
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
}); });
} }

View File

@@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
add_custom_target(pim-unittest) add_custom_target(pim-unittest)
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests") set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")