This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
add_pim_library(OMPimCommon
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
@@ -126,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||
return current->getBlock();
|
||||
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
@@ -30,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
|
||||
@@ -21,12 +21,13 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
found |= mvmOp.getWeight() == weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
found |= vmmOp.getWeight() == weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
Reference in New Issue
Block a user