This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
@@ -126,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user