fix weightAlways attribute in spatial
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
@@ -96,6 +97,53 @@ void markWeightAlways(Operation* op) {
|
||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
bool found = false;
|
||||
parentOp.walk([&](Operation* op) {
|
||||
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
|
||||
Operation* user = use.getOwner();
|
||||
unsigned operandIndex = use.getOperandNumber();
|
||||
|
||||
auto computeOp = dyn_cast<spatial::SpatCompute>(user);
|
||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
|
||||
});
|
||||
}
|
||||
|
||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||
if (!moduleOp || !getGlobalOp)
|
||||
return {};
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
@@ -40,6 +41,10 @@ bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
llvm::FailureOr<mlir::Operation*>
|
||||
|
||||
@@ -392,7 +392,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
||||
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
|
||||
return isSpatialMvmVmmWeightUse(use);
|
||||
});
|
||||
if (isAlwaysWeight)
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
|
||||
@@ -94,10 +94,8 @@ void PimBufferizationPass::runOnOperation() {
|
||||
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](PimCoreOp coreOp) {
|
||||
auto annotateWeight = [&](unsigned weightIndex) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return;
|
||||
Value weight = coreOp.getWeights()[weightIndex];
|
||||
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
|
||||
Value weight = weightUse.get();
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return;
|
||||
@@ -105,10 +103,7 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
markWeightAlways(getGlobalOp);
|
||||
markWeightAlways(globalMemrefOp);
|
||||
};
|
||||
|
||||
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
||||
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
add_custom_target(pim-unittest)
|
||||
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user