fix weightAlways attribute in spatial
This commit is contained in:
@@ -4,6 +4,7 @@
|
|||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
@@ -96,6 +97,53 @@ void markWeightAlways(Operation* op) {
|
|||||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||||
|
bool found = false;
|
||||||
|
parentOp.walk([&](Operation* op) {
|
||||||
|
if (auto mvmOp = dyn_cast<MVMOpTy>(op))
|
||||||
|
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||||
|
else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
|
||||||
|
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||||
|
});
|
||||||
|
return found;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||||
|
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
|
||||||
|
auto weights = parentOp.getWeights();
|
||||||
|
llvm::SmallSet<unsigned, 8> visited;
|
||||||
|
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||||
|
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||||
|
callback(parentOp->getOpOperand(weightIndex));
|
||||||
|
};
|
||||||
|
|
||||||
|
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
|
||||||
|
Operation* user = use.getOwner();
|
||||||
|
unsigned operandIndex = use.getOperandNumber();
|
||||||
|
|
||||||
|
auto computeOp = dyn_cast<spatial::SpatCompute>(user);
|
||||||
|
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return hasMvmVmmWeightUse<spatial::SpatWeightedMVMOp, spatial::SpatWeightedVMMOp>(computeOp, operandIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
|
||||||
|
assert(root && "expected valid root op");
|
||||||
|
root->walk([&](pim::PimCoreOp coreOp) {
|
||||||
|
walkMvmVmmWeightUses<pim::PimMVMOp, pim::PimVMMOp>(coreOp, callback);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!moduleOp || !getGlobalOp)
|
if (!moduleOp || !getGlobalOp)
|
||||||
return {};
|
return {};
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
@@ -40,6 +41,10 @@ bool hasWeightAlways(mlir::Operation* op);
|
|||||||
|
|
||||||
void markWeightAlways(mlir::Operation* op);
|
void markWeightAlways(mlir::Operation* op);
|
||||||
|
|
||||||
|
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||||
|
|
||||||
|
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||||
|
|
||||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||||
|
|
||||||
llvm::FailureOr<mlir::Operation*>
|
llvm::FailureOr<mlir::Operation*>
|
||||||
|
|||||||
@@ -392,7 +392,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
bool isAlwaysWeight =
|
bool isAlwaysWeight =
|
||||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
!constantOp->use_empty() && llvm::all_of(constantOp->getUses(), [](OpOperand& use) -> bool {
|
||||||
|
return isSpatialMvmVmmWeightUse(use);
|
||||||
|
});
|
||||||
if (isAlwaysWeight)
|
if (isAlwaysWeight)
|
||||||
markWeightAlways(constantOp);
|
markWeightAlways(constantOp);
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -94,10 +94,8 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](PimCoreOp coreOp) {
|
funcOp.walk([&](PimCoreOp coreOp) {
|
||||||
auto annotateWeight = [&](unsigned weightIndex) {
|
walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) {
|
||||||
if (weightIndex >= coreOp.getWeights().size())
|
Value weight = weightUse.get();
|
||||||
return;
|
|
||||||
Value weight = coreOp.getWeights()[weightIndex];
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp)
|
if (!getGlobalOp)
|
||||||
return;
|
return;
|
||||||
@@ -105,10 +103,7 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO
|
|||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
markWeightAlways(getGlobalOp);
|
markWeightAlways(getGlobalOp);
|
||||||
markWeightAlways(globalMemrefOp);
|
markWeightAlways(globalMemrefOp);
|
||||||
};
|
});
|
||||||
|
|
||||||
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
|
||||||
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
add_custom_target(pim-unittest)
|
add_custom_target(pim-unittest)
|
||||||
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
set_target_properties(pim-unittest PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user