808 lines
33 KiB
C++
808 lines
33 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
|
|
#include <limits>
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Finds the `memref.global` that `getGlobalOp` refers to by looking its symbol
/// name up in `moduleOp`.
///
/// Returns a null `GlobalOp` when either argument is null or when the lookup
/// does not yield a `memref.global`.
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
  mlir::memref::GlobalOp referencedGlobal;
  if (moduleOp && getGlobalOp)
    referencedGlobal = moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
  return referencedGlobal;
}
|
|
|
|
namespace {
|
|
|
|
/// Follows the alias chain recorded in `knowledge->aliases`, starting at
/// `value`, and returns the first value in the chain that has no alias entry.
///
/// Returns `value` unchanged when `knowledge` is null.
/// NOTE(review): a cyclic alias map would make this loop forever.
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
  if (knowledge == nullptr)
    return value;

  const auto& aliasMap = knowledge->aliases;
  for (auto entry = aliasMap.find(value); entry != aliasMap.end(); entry = aliasMap.find(value))
    value = entry->second;
  return value;
}
|
|
|
|
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
|
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
|
|
|
template <typename... Args>
|
|
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
|
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
|
}
|
|
|
|
/// Strips in-place and view-like wrappers from `value`.
///
/// At every step the value is first alias-resolved through `knowledge` (may be
/// null). The walk then moves:
///  - from a destination-style op result to its tied operand;
///  - from `memref.cast`, `memref.collapse_shape` or `memref.expand_shape` to
///    its source.
///
/// The walk stops at a block argument, at a value without a defining op, or at
/// any other producer. The stopping value is returned.
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  mlir::Value current = value;
  while (true) {
    current = resolveAlias(current, knowledge);
    if (mlir::isa<mlir::BlockArgument>(current))
      return current;

    mlir::Operation* producer = current.getDefiningOp();
    if (!producer)
      return current;

    mlir::Value underlying;
    // In-place ops write into their tied init operand.
    if (auto dpsProducer = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(producer)) {
      if (auto opResult = mlir::dyn_cast<mlir::OpResult>(current))
        if (mlir::OpOperand* tied = dpsProducer.getTiedOpOperand(opResult))
          underlying = tied->get();
    }

    if (!underlying) {
      if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(producer))
        underlying = castOp.getSource();
      else if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(producer))
        underlying = collapseOp.getSrc();
      else if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(producer))
        underlying = expandOp.getSrc();
    }

    if (!underlying)
      return current;
    current = underlying;
  }
}
|
|
|
|
// Forward declarations for helpers defined later in this file.
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);

/// Returns the strides of `type`'s layout.
///
/// Fails if the layout is not expressible as strides-and-offset, or if any
/// stride is dynamic. The layout offset is discarded.
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
  int64_t discardedOffset = 0;
  llvm::SmallVector<int64_t> layoutStrides;
  if (failed(type.getStridesAndOffset(layoutStrides, discardedOffset)))
    return mlir::failure();

  for (int64_t stride : layoutStrides)
    if (mlir::ShapedType::isDynamic(stride))
      return mlir::failure();

  return layoutStrides;
}
|
|
|
|
/// Folds a `memref.load` from a constant `memref.global` to the loaded element,
/// sign-extended to 64 bits.
///
/// Every load index must itself resolve through `resolveIndexValueImpl` under
/// `knowledge`. Fails when any of the following holds:
///  - the memref is not produced by `memref.get_global`;
///  - the global is not constant, has no dense initializer, or is not
///    statically shaped;
///  - the element type is neither integer nor index;
///  - the index count does not match the rank;
///  - any index lies outside its dimension.
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
                                                          const StaticValueKnowledge* knowledge) {
  auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
  if (!getGlobalOp)
    return mlir::failure();

  auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
    return mlir::failure();

  auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
  auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
  if (!denseAttr || !globalType || !globalType.hasStaticShape())
    return mlir::failure();

  auto elementType = denseAttr.getElementType();
  if (!elementType.isIndex() && !elementType.isInteger())
    return mlir::failure();

  llvm::SmallVector<int64_t> indices;
  indices.reserve(loadOp.getIndices().size());
  for (mlir::Value index : loadOp.getIndices()) {
    auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
    if (failed(resolvedIndex))
      return mlir::failure();
    indices.push_back(*resolvedIndex);
  }

  if (indices.size() != static_cast<size_t>(globalType.getRank()))
    return mlir::failure();

  // Check every coordinate against its own dimension. An out-of-range
  // coordinate can still linearize onto a different, in-range element; with
  // row-major strides, [0, 3] on a 2x3 global lands on [1, 0].
  auto shape = globalType.getShape();
  for (size_t dim = 0; dim < indices.size(); ++dim)
    if (indices[dim] < 0 || indices[dim] >= shape[dim])
      return mlir::failure();

  auto strides = computeRowMajorStrides(shape);
  int64_t linearIndex = linearizeIndex(indices, strides);
  if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
    return mlir::failure();

  return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
|
|
|
|
/// Evaluates `lhs <predicate> rhs` on 64-bit values.
///
/// Signed and equality predicates compare the `int64_t` values directly.
/// Unsigned predicates compare the same bits reinterpreted as `uint64_t`.
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
  using Predicate = mlir::arith::CmpIPredicate;
  const uint64_t lhsBits = static_cast<uint64_t>(lhs);
  const uint64_t rhsBits = static_cast<uint64_t>(rhs);

  switch (predicate) {
  case Predicate::eq: return lhs == rhs;
  case Predicate::ne: return lhs != rhs;
  case Predicate::slt: return lhs < rhs;
  case Predicate::sle: return lhs <= rhs;
  case Predicate::sgt: return lhs > rhs;
  case Predicate::sge: return lhs >= rhs;
  case Predicate::ult: return lhsBits < rhsBits;
  case Predicate::ule: return lhsBits <= rhsBits;
  case Predicate::ugt: return lhsBits > rhsBits;
  case Predicate::uge: return lhsBits >= rhsBits;
  }

  llvm_unreachable("unknown cmpi predicate");
}
|
|
|
|
/// Evaluates a compiled index expression against `knowledge`.
///
/// `Symbol` leaves are alias-resolved and looked up in
/// `knowledge.indexValues`. Every other node kind is computed from its
/// operands. Add/Sub/Mul wrap modulo 2^64 (two's complement) rather than
/// relying on signed `int64_t` overflow, which is undefined behaviour in C++.
///
/// Fails when any of the following holds:
///  - a symbol is unknown;
///  - a division or remainder is by zero;
///  - a `divsi` overflows (INT64_MIN / -1);
///  - a node has fewer operands than its kind needs;
///  - a constant-global load coordinate is outside the global's shape.
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
                                                   const StaticValueKnowledge& knowledge) {
  if (!expr.node)
    return mlir::failure();

  auto& node = *expr.node;
  auto toUnsigned = [](int64_t v) { return static_cast<uint64_t>(v); };
  auto toSigned = [](uint64_t v) { return static_cast<int64_t>(v); };

  switch (node.kind) {
  case CompiledIndexExprNode::Kind::Constant: return node.constant;
  case CompiledIndexExprNode::Kind::Symbol: {
    auto value = resolveAlias(node.symbol, &knowledge);
    auto iter = knowledge.indexValues.find(value);
    if (iter != knowledge.indexValues.end())
      return iter->second;
    return mlir::failure();
  }
  case CompiledIndexExprNode::Kind::Add:
  case CompiledIndexExprNode::Kind::Sub:
  case CompiledIndexExprNode::Kind::Mul:
  case CompiledIndexExprNode::Kind::DivUI:
  case CompiledIndexExprNode::Kind::DivSI:
  case CompiledIndexExprNode::Kind::RemUI:
  case CompiledIndexExprNode::Kind::RemSI:
  case CompiledIndexExprNode::Kind::MinUI:
  case CompiledIndexExprNode::Kind::CmpI: {
    // A malformed node must not be read out of bounds.
    if (node.operands.size() < 2)
      return mlir::failure();

    auto lhs = evaluateCompiledIndexExpr(node.operands[0], knowledge);
    auto rhs = evaluateCompiledIndexExpr(node.operands[1], knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();

    switch (node.kind) {
    case CompiledIndexExprNode::Kind::Add: return toSigned(toUnsigned(*lhs) + toUnsigned(*rhs));
    case CompiledIndexExprNode::Kind::Sub: return toSigned(toUnsigned(*lhs) - toUnsigned(*rhs));
    case CompiledIndexExprNode::Kind::Mul: return toSigned(toUnsigned(*lhs) * toUnsigned(*rhs));
    case CompiledIndexExprNode::Kind::DivUI:
      if (*rhs == 0)
        return mlir::failure();
      return toSigned(toUnsigned(*lhs) / toUnsigned(*rhs));
    case CompiledIndexExprNode::Kind::DivSI:
      if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
        return mlir::failure();
      return *lhs / *rhs;
    case CompiledIndexExprNode::Kind::RemUI:
      if (*rhs == 0)
        return mlir::failure();
      return toSigned(toUnsigned(*lhs) % toUnsigned(*rhs));
    case CompiledIndexExprNode::Kind::RemSI:
      if (*rhs == 0)
        return mlir::failure();
      // INT64_MIN % -1 is mathematically 0 but undefined behaviour in C++.
      if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
        return 0;
      return *lhs % *rhs;
    case CompiledIndexExprNode::Kind::MinUI: return toSigned(std::min(toUnsigned(*lhs), toUnsigned(*rhs)));
    case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(node.predicate, *lhs, *rhs) ? 1 : 0;
    default: llvm_unreachable("unexpected binary compiled index kind");
    }
  }
  case CompiledIndexExprNode::Kind::Select: {
    if (node.operands.size() < 3)
      return mlir::failure();

    auto condition = evaluateCompiledIndexExpr(node.operands[0], knowledge);
    if (failed(condition))
      return mlir::failure();
    return evaluateCompiledIndexExpr(*condition != 0 ? node.operands[1] : node.operands[2], knowledge);
  }
  case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
    mlir::memref::GlobalOp globalOp = node.globalOp;
    if (!globalOp || !globalOp.getInitialValue())
      return mlir::failure();

    auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
    auto globalType = mlir::dyn_cast<mlir::MemRefType>(globalOp.getType());
    if (!denseAttr || !globalType || !globalType.hasStaticShape())
      return mlir::failure();

    auto shape = globalType.getShape();
    if (node.operands.size() != shape.size())
      return mlir::failure();

    llvm::SmallVector<int64_t> indices;
    indices.reserve(node.operands.size());
    for (size_t dim = 0; dim < node.operands.size(); ++dim) {
      auto resolvedIndex = evaluateCompiledIndexExpr(node.operands[dim], knowledge);
      if (failed(resolvedIndex))
        return mlir::failure();
      // An out-of-range coordinate could otherwise linearize onto a different,
      // in-range element.
      if (*resolvedIndex < 0 || *resolvedIndex >= shape[dim])
        return mlir::failure();
      indices.push_back(*resolvedIndex);
    }

    int64_t linearIndex = linearizeIndex(indices, node.globalStrides);
    if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
      return mlir::failure();
    return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
  }
  }

  llvm_unreachable("unknown compiled index kind");
}
|
|
|
|
/// Compiles a `memref.load` from a constant `memref.global` into a
/// `ConstantGlobalLoad` node.
///
/// The global must be constant, statically shaped, and have an integer- or
/// index-typed dense initializer. The node records the global, its row-major
/// strides, and one compiled operand per load index. Fails if the load does
/// not read such a global or if any index fails to compile.
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
  auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
  if (!getGlobalOp)
    return mlir::failure();

  mlir::memref::GlobalOp globalOp = lookupGlobalForGetGlobal(loadOp->getParentOfType<mlir::ModuleOp>(), getGlobalOp);
  const bool isFoldableGlobal = globalOp && globalOp.getConstant() && globalOp.getInitialValue();
  if (!isFoldableGlobal)
    return mlir::failure();

  auto initializer = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
  auto memrefType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
  if (!initializer || !memrefType || !memrefType.hasStaticShape())
    return mlir::failure();

  mlir::Type storedType = initializer.getElementType();
  if (!(storedType.isIndex() || storedType.isInteger()))
    return mlir::failure();

  CompiledIndexExprNode loadNode;
  loadNode.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
  loadNode.globalOp = globalOp;
  loadNode.globalStrides = computeRowMajorStrides(memrefType.getShape());
  loadNode.operands.reserve(loadOp.getIndices().size());

  for (mlir::Value indexValue : loadOp.getIndices()) {
    llvm::FailureOr<CompiledIndexExpr> compiledIndex = compileIndexValueImpl(indexValue);
    if (failed(compiledIndex))
      return mlir::failure();
    loadNode.operands.push_back(*compiledIndex);
  }

  return makeCompiledIndexExpr(std::move(loadNode));
}
|
|
|
|
/// Lowers the integer computation producing `value` into a `CompiledIndexExpr`
/// tree. The tree can later be evaluated against different
/// `StaticValueKnowledge`.
///
/// Lowering rules:
///  - `arith.constant` integers become `Constant` leaves.
///  - Block arguments, and values from any op not listed here, become
///    `Symbol` leaves resolved at evaluation time.
///  - `arith.index_cast` is looked through.
///  - `addi`, `subi`, `muli`, `divui`, `divsi`, `minui`, `remui`, `remsi` and
///    `cmpi` become binary nodes.
///  - `maxui` is expressed as `select(cmpi uge, lhs, rhs)`.
///  - `arith.select` becomes a `Select` node.
///  - `memref.load` is delegated to `compileConstantGlobalLoad`.
///
/// Fails if any required sub-expression fails to compile.
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
  using Kind = CompiledIndexExprNode::Kind;

  // Fresh node of the given kind, without operands.
  auto nodeOfKind = [](Kind kind) {
    CompiledIndexExprNode node;
    node.kind = kind;
    return node;
  };

  // Leaf that defers `value` to evaluation time.
  auto symbolLeaf = [&]() {
    CompiledIndexExprNode leaf = nodeOfKind(Kind::Symbol);
    leaf.symbol = value;
    return makeCompiledIndexExpr(std::move(leaf));
  };

  // Compiles both operands and attaches them to `node`. The node's kind, and
  // its predicate for CmpI, are already set.
  auto withOperands = [](CompiledIndexExprNode node, mlir::Value lhsValue,
                         mlir::Value rhsValue) -> llvm::FailureOr<CompiledIndexExpr> {
    llvm::FailureOr<CompiledIndexExpr> lhs = compileIndexValueImpl(lhsValue);
    llvm::FailureOr<CompiledIndexExpr> rhs = compileIndexValueImpl(rhsValue);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    node.operands = {*lhs, *rhs};
    return makeCompiledIndexExpr(std::move(node));
  };

  if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
    if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
      CompiledIndexExprNode leaf = nodeOfKind(Kind::Constant);
      leaf.constant = integerAttr.getInt();
      return makeCompiledIndexExpr(std::move(leaf));
    }
  }

  mlir::Operation* producer = value.getDefiningOp();
  if (!producer)
    return symbolLeaf();

  if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(producer))
    return compileIndexValueImpl(indexCastOp.getIn());
  if (auto op = mlir::dyn_cast<mlir::arith::AddIOp>(producer))
    return withOperands(nodeOfKind(Kind::Add), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::SubIOp>(producer))
    return withOperands(nodeOfKind(Kind::Sub), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::MulIOp>(producer))
    return withOperands(nodeOfKind(Kind::Mul), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::DivUIOp>(producer))
    return withOperands(nodeOfKind(Kind::DivUI), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::DivSIOp>(producer))
    return withOperands(nodeOfKind(Kind::DivSI), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::MinUIOp>(producer))
    return withOperands(nodeOfKind(Kind::MinUI), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::RemUIOp>(producer))
    return withOperands(nodeOfKind(Kind::RemUI), op.getLhs(), op.getRhs());
  if (auto op = mlir::dyn_cast<mlir::arith::RemSIOp>(producer))
    return withOperands(nodeOfKind(Kind::RemSI), op.getLhs(), op.getRhs());

  if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(producer)) {
    CompiledIndexExprNode cmpNode = nodeOfKind(Kind::CmpI);
    cmpNode.predicate = cmpOp.getPredicate();
    return withOperands(std::move(cmpNode), cmpOp.getLhs(), cmpOp.getRhs());
  }

  if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(producer)) {
    llvm::FailureOr<CompiledIndexExpr> lhs = compileIndexValueImpl(maxOp.getLhs());
    llvm::FailureOr<CompiledIndexExpr> rhs = compileIndexValueImpl(maxOp.getRhs());
    if (failed(lhs) || failed(rhs))
      return mlir::failure();

    // maxui(a, b) == (a >=u b) ? a : b
    CompiledIndexExprNode cmpNode = nodeOfKind(Kind::CmpI);
    cmpNode.predicate = mlir::arith::CmpIPredicate::uge;
    cmpNode.operands = {*lhs, *rhs};

    CompiledIndexExprNode selectNode = nodeOfKind(Kind::Select);
    selectNode.operands = {makeCompiledIndexExpr(std::move(cmpNode)), *lhs, *rhs};
    return makeCompiledIndexExpr(std::move(selectNode));
  }

  if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(producer)) {
    llvm::FailureOr<CompiledIndexExpr> condition = compileIndexValueImpl(selectOp.getCondition());
    llvm::FailureOr<CompiledIndexExpr> onTrue = compileIndexValueImpl(selectOp.getTrueValue());
    llvm::FailureOr<CompiledIndexExpr> onFalse = compileIndexValueImpl(selectOp.getFalseValue());
    if (failed(condition) || failed(onTrue) || failed(onFalse))
      return mlir::failure();

    CompiledIndexExprNode selectNode = nodeOfKind(Kind::Select);
    selectNode.operands = {*condition, *onTrue, *onFalse};
    return makeCompiledIndexExpr(std::move(selectNode));
  }

  if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(producer))
    return compileConstantGlobalLoad(loadOp);

  return symbolLeaf();
}
|
|
|
|
/// Folds `value` to a concrete 64-bit integer.
///
/// Resolution order:
///  1. Follow aliases recorded in `knowledge` (may be null).
///  2. Consult `knowledge->indexValues`.
///  3. Otherwise evaluate recursively:
///     - `arith.constant` and `arith.index_cast`;
///     - integer arithmetic (`addi`, `subi`, `muli`, `divui`, `divsi`,
///       `remui`, `remsi`, `minui`, `maxui`);
///     - `arith.cmpi` and `arith.select`;
///     - loads from constant globals.
///
/// `addi`/`subi`/`muli` wrap modulo 2^64 (two's complement) instead of relying
/// on signed overflow, which is undefined behaviour in C++.
///
/// Fails when a needed leaf is unknown, on division/remainder by zero, and on
/// `divsi` overflow (INT64_MIN / -1).
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>())
    if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();

  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return mlir::failure();

  if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  auto toUnsigned = [](int64_t v) { return static_cast<uint64_t>(v); };
  auto toSigned = [](uint64_t v) { return static_cast<int64_t>(v); };

  // Resolves both operands (lhs first), then combines them with `combine`,
  // which may itself fail.
  auto resolveBinary = [&](mlir::Value lhsValue, mlir::Value rhsValue, auto combine) -> llvm::FailureOr<int64_t> {
    auto lhs = resolveIndexValueImpl(lhsValue, knowledge);
    auto rhs = resolveIndexValueImpl(rhsValue, knowledge);
    if (failed(lhs) || failed(rhs))
      return mlir::failure();
    return combine(*lhs, *rhs);
  };

  if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
    return resolveBinary(addOp.getLhs(), addOp.getRhs(), [&](int64_t lhs, int64_t rhs) {
      return toSigned(toUnsigned(lhs) + toUnsigned(rhs));
    });

  if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
    return resolveBinary(subOp.getLhs(), subOp.getRhs(), [&](int64_t lhs, int64_t rhs) {
      return toSigned(toUnsigned(lhs) - toUnsigned(rhs));
    });

  if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
    return resolveBinary(mulOp.getLhs(), mulOp.getRhs(), [&](int64_t lhs, int64_t rhs) {
      return toSigned(toUnsigned(lhs) * toUnsigned(rhs));
    });

  if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
    return resolveBinary(divOp.getLhs(), divOp.getRhs(), [&](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
      if (rhs == 0)
        return mlir::failure();
      return toSigned(toUnsigned(lhs) / toUnsigned(rhs));
    });

  if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
    return resolveBinary(divOp.getLhs(), divOp.getRhs(), [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
      if (rhs == 0 || (lhs == std::numeric_limits<int64_t>::min() && rhs == -1))
        return mlir::failure();
      return lhs / rhs;
    });

  if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
    return resolveBinary(minOp.getLhs(), minOp.getRhs(), [&](int64_t lhs, int64_t rhs) {
      return toSigned(std::min(toUnsigned(lhs), toUnsigned(rhs)));
    });

  // Mirrors compileIndexValueImpl, which also understands arith.maxui.
  if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp))
    return resolveBinary(maxOp.getLhs(), maxOp.getRhs(), [&](int64_t lhs, int64_t rhs) {
      uint64_t lhsBits = toUnsigned(lhs);
      uint64_t rhsBits = toUnsigned(rhs);
      return toSigned(lhsBits >= rhsBits ? lhsBits : rhsBits);
    });

  if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
    return resolveBinary(remOp.getLhs(), remOp.getRhs(), [&](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
      if (rhs == 0)
        return mlir::failure();
      return toSigned(toUnsigned(lhs) % toUnsigned(rhs));
    });

  if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
    return resolveBinary(remOp.getLhs(), remOp.getRhs(), [](int64_t lhs, int64_t rhs) -> llvm::FailureOr<int64_t> {
      if (rhs == 0)
        return mlir::failure();
      // INT64_MIN % -1 is mathematically 0 but undefined behaviour in C++.
      if (lhs == std::numeric_limits<int64_t>::min() && rhs == -1)
        return 0;
      return lhs % rhs;
    });

  if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp))
    return resolveBinary(cmpOp.getLhs(), cmpOp.getRhs(), [&](int64_t lhs, int64_t rhs) -> int64_t {
      return evaluateCmpPredicate(cmpOp.getPredicate(), lhs, rhs) ? 1 : 0;
    });

  if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
    auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
    if (failed(condition))
      return mlir::failure();
    // Only the chosen branch needs to be resolvable.
    return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
  }

  if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
    return resolveConstantGlobalLoad(loadOp, knowledge);

  return mlir::failure();
}
|
|
|
|
/// Resolves a mixed static/dynamic operand to an integer.
///
/// SSA values are folded through `resolveIndexValueImpl` under `knowledge`.
/// Integer attributes yield their value. Any other attribute fails.
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  if (!mlir::isa<mlir::Attribute>(ofr))
    return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);

  auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(mlir::cast<mlir::Attribute>(ofr));
  if (integerAttr)
    return integerAttr.getInt();
  return mlir::failure();
}
|
|
|
|
/// Walks `value` back to the buffer it is a contiguous window of, and returns
/// that base with the accumulated byte offset.
///
/// Aliases and dynamic subview operands are resolved through `knowledge`
/// (may be null). The walk moves through:
///  - destination-style ops, to their tied operand;
///  - `scf.for` results, to the yielded value, or to the matching init argument
///    when the yielded value is the loop's own iteration argument;
///  - `memref.subview`, adding `offsets . sourceStrides * elementBytes`;
///  - `memref.cast`, `memref.collapse_shape` and `memref.expand_shape`.
///
/// The walk stops at a block argument, `memref.alloc` or `memref.get_global`.
/// Fails on any other producer, non-static subview shapes or strides,
/// unresolvable subview operands, and non-contiguous subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
                                                                        const StaticValueKnowledge* knowledge) {
  // Resolves each entry of a mixed static/dynamic subview list, in order.
  auto resolveMixedList = [&](const auto& mixedList) -> llvm::FailureOr<llvm::SmallVector<int64_t>> {
    llvm::SmallVector<int64_t> resolvedList;
    resolvedList.reserve(mixedList.size());
    for (mlir::OpFoldResult entry : mixedList) {
      llvm::FailureOr<int64_t> resolvedEntry = resolveOpFoldResult(entry, knowledge);
      if (failed(resolvedEntry))
        return mlir::failure();
      resolvedList.push_back(*resolvedEntry);
    }
    return resolvedList;
  };

  int64_t accumulatedBytes = 0;
  mlir::Value current = resolveAlias(value, knowledge);

  for (;;) {
    if (mlir::isa<mlir::BlockArgument>(current))
      return ResolvedContiguousAddress {current, accumulatedBytes};

    mlir::Operation* producer = current.getDefiningOp();
    if (!producer)
      return mlir::failure();

    mlir::Value next;
    if (auto dpsProducer = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(producer)) {
      // An in-place result lives in its tied init operand.
      mlir::OpOperand* tiedOperand = dpsProducer.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(current));
      if (!tiedOperand)
        return mlir::failure();
      next = tiedOperand->get();
    } else if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(producer)) {
      auto loopResult = mlir::dyn_cast<mlir::OpResult>(current);
      if (!loopResult)
        return mlir::failure();

      auto terminator = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
      next = resolveLoopCarriedAliasImpl(terminator.getOperand(loopResult.getResultNumber()), knowledge);

      // A result that is just the (updated) iteration argument lives in the
      // corresponding init operand. Body argument 0 is the induction variable.
      auto iterArg = mlir::dyn_cast<mlir::BlockArgument>(next);
      if (iterArg && iterArg.getOwner() == forOp.getBody() && iterArg.getArgNumber() > 0
          && static_cast<unsigned>(iterArg.getArgNumber() - 1) < forOp.getInitArgs().size())
        next = forOp.getInitArgs()[iterArg.getArgNumber() - 1];
    } else if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(producer)) {
      auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
      auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
        return mlir::failure();

      auto offsets = resolveMixedList(subviewOp.getMixedOffsets());
      if (failed(offsets))
        return mlir::failure();
      auto sizes = resolveMixedList(subviewOp.getMixedSizes());
      if (failed(sizes))
        return mlir::failure();
      auto strides = resolveMixedList(subviewOp.getMixedStrides());
      if (failed(strides))
        return mlir::failure();

      if (!isMemoryContiguous(sourceType.getShape(), *offsets, *sizes, *strides))
        return mlir::failure();

      auto sourceStrides = getStaticMemRefStrides(sourceType);
      if (failed(sourceStrides))
        return mlir::failure();

      accumulatedBytes +=
          linearizeIndex(*offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
      next = subviewOp.getSource();
    } else if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(producer)) {
      next = castOp.getSource();
    } else if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(producer)) {
      next = collapseOp.getSrc();
    } else if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(producer)) {
      next = expandOp.getSrc();
    } else if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(producer)) {
      return ResolvedContiguousAddress {current, accumulatedBytes};
    } else {
      return mlir::failure();
    }

    current = resolveAlias(next, knowledge);
  }
}
|
|
|
|
/// Compiles the address of `value` into a base buffer plus a byte-offset
/// expression. The expression may still contain runtime index values
/// (`Symbol` leaves) and is evaluated later under a `StaticValueKnowledge`.
///
/// The walk mirrors `resolveContiguousAddressImpl`. It moves through:
///  - destination-style ops, to their tied operand;
///  - `scf.for` results, to the yielded value, or to the matching init argument
///    when the yielded value is the loop's own iteration argument;
///  - `memref.subview`;
///  - `memref.cast`, `memref.collapse_shape` and `memref.expand_shape`.
///
/// The walk stops at a block argument, `memref.alloc` or `memref.get_global`.
///
/// Each subview offset contributes `offset * sourceStride * elementBytes`.
/// Static contributions are summed into a constant. Dynamic ones are compiled
/// with `compileIndexValueImpl` and summed into an expression. Contributions
/// from every subview on the chain are kept, and the constant part is included
/// whichever kind of base ends the walk.
///
/// Fails on any other producer, dynamic subview sizes or strides, non-static
/// shapes or layouts, uncompilable offsets, and non-contiguous subviews.
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
  // Leaf holding a compile-time constant.
  auto makeConstant = [](int64_t constant) {
    CompiledIndexExprNode expr;
    expr.kind = CompiledIndexExprNode::Kind::Constant;
    expr.constant = constant;
    return makeCompiledIndexExpr(std::move(expr));
  };
  // True for a literal-zero leaf. Such terms are dropped from sums.
  auto isZeroConstant = [](const CompiledIndexExpr& expr) {
    return expr.node && expr.node->kind == CompiledIndexExprNode::Kind::Constant && expr.node->constant == 0;
  };
  // Builds `lhs + rhs`, omitting an operand that is a literal zero.
  auto makeAdd = [&](const CompiledIndexExpr& lhs, const CompiledIndexExpr& rhs) -> CompiledIndexExpr {
    if (isZeroConstant(lhs))
      return rhs;
    if (isZeroConstant(rhs))
      return lhs;
    CompiledIndexExprNode expr;
    expr.kind = CompiledIndexExprNode::Kind::Add;
    expr.operands = {lhs, rhs};
    return makeCompiledIndexExpr(std::move(expr));
  };

  // Sum of all static subview offsets seen so far, in bytes.
  int64_t constantByteOffset = 0;
  // Sum of all dynamic subview offsets seen so far, as a byte expression.
  CompiledIndexExpr dynamicByteOffsetExpr = makeConstant(0);
  // Final address for `base`: the constant part plus the dynamic part.
  auto finish = [&](mlir::Value base) -> llvm::FailureOr<CompiledAddressExpr> {
    return CompiledAddressExpr {base, makeAdd(makeConstant(constantByteOffset), dynamicByteOffsetExpr)};
  };

  while (true) {
    if (mlir::isa<mlir::BlockArgument>(value))
      return finish(value);

    mlir::Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return mlir::failure();

    if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
      // An in-place result lives in its tied init operand.
      mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
      if (!tiedOperand)
        return mlir::failure();
      value = tiedOperand->get();
      continue;
    }

    if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
      auto result = mlir::dyn_cast<mlir::OpResult>(value);
      if (!result)
        return mlir::failure();

      auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
      // Look through in-place updates of the iteration argument (as
      // resolveContiguousAddressImpl does). Otherwise the base would end up
      // being the loop body's block argument instead of the init operand.
      mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), nullptr);
      if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
        // Body argument 0 is the induction variable; the rest map to init args.
        if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
            && static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
          value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
          continue;
        }
      }

      value = yieldedValue;
      continue;
    }

    if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
      auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
      auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
        return mlir::failure();

      llvm::SmallVector<int64_t> staticOffsets;
      staticOffsets.reserve(subviewOp.getMixedOffsets().size());
      llvm::SmallVector<int64_t> staticSizes;
      staticSizes.reserve(subviewOp.getMixedSizes().size());
      llvm::SmallVector<int64_t> staticStrides;
      staticStrides.reserve(subviewOp.getMixedStrides().size());
      bool hasOnlyStaticOffsets = true;

      for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
        if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
          staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
        else
          hasOnlyStaticOffsets = false;
      }
      // Sizes and strides must be static; only offsets may be dynamic.
      for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
        auto attr = mlir::dyn_cast<mlir::Attribute>(size);
        if (!attr)
          return mlir::failure();
        staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
      }
      for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
        auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
        if (!attr)
          return mlir::failure();
        staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
      }

      if (!isContiguousSubviewWithDynamicOffsets(
              sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides))
        return mlir::failure();

      auto sourceStrides = getStaticMemRefStrides(sourceType);
      if (failed(sourceStrides))
        return mlir::failure();
      const int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(subviewType.getElementType()));

      if (hasOnlyStaticOffsets) {
        if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
          return mlir::failure();
        constantByteOffset += linearizeIndex(staticOffsets, *sourceStrides) * elementBytes;
      } else {
        for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
          const int64_t bytesPerUnit = sourceStride * elementBytes;
          if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
            constantByteOffset += mlir::cast<mlir::IntegerAttr>(attr).getInt() * bytesPerUnit;
            continue;
          }

          auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
          if (failed(compiledOffset))
            return mlir::failure();

          CompiledIndexExprNode scaled;
          scaled.kind = CompiledIndexExprNode::Kind::Mul;
          scaled.operands = {*compiledOffset, makeConstant(bytesPerUnit)};
          // Accumulate, so offsets from outer dynamic subviews are preserved.
          dynamicByteOffsetExpr = makeAdd(dynamicByteOffsetExpr, makeCompiledIndexExpr(std::move(scaled)));
        }
      }

      value = subviewOp.getSource();
      continue;
    }

    if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
      value = castOp.getSource();
      continue;
    }
    if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
      value = collapseOp.getSrc();
      continue;
    }
    if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
      value = expandOp.getSrc();
      continue;
    }

    if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
      return finish(value);

    return mlir::failure();
  }
}
|
|
|
|
} // namespace
|
|
|
|
/// Folds `value` to a concrete integer using `knowledge` for aliases and known
/// index values. Fails if the value cannot be determined.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveIndexValueImpl(value, knowledgePtr);
}
|
|
|
|
/// Compiles the computation producing `value` into a reusable index
/// expression that can be evaluated later with `CompiledIndexExpr::evaluate`.
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) {
  llvm::FailureOr<CompiledIndexExpr> compiled = compileIndexValueImpl(value);
  return compiled;
}
|
|
|
|
/// Resolves `value` to its base buffer and constant byte offset, using
/// `knowledge` for aliases and dynamic subview operands.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
                                                                    const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveContiguousAddressImpl(value, knowledgePtr);
}
|
|
|
|
/// Strips alias, destination-style, and cast/reshape wrappers from `value`
/// using `knowledge`, and returns the underlying value.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveLoopCarriedAliasImpl(value, knowledgePtr);
}
|
|
|
|
/// Compiles the address of `value` into a base buffer plus a byte-offset
/// expression that can be evaluated later with `CompiledAddressExpr::evaluate`.
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
  llvm::FailureOr<CompiledAddressExpr> compiled = compileContiguousAddressExprImpl(value);
  return compiled;
}
|
|
|
|
/// Evaluates this compiled expression against `knowledge`. Fails if a symbol
/// is unknown or the arithmetic is undefined (e.g. division by zero).
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
  const CompiledIndexExpr& self = *this;
  return evaluateCompiledIndexExpr(self, knowledge);
}
|
|
|
|
/// Evaluates the byte-offset expression under `knowledge` and pairs the result
/// with the compiled base value. Fails if the offset cannot be evaluated.
///
/// `lane` is currently ignored.
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
                                                                         std::optional<unsigned> lane) const {
  static_cast<void>(lane);
  llvm::FailureOr<int64_t> offsetInBytes = byteOffset.evaluate(knowledge);
  if (!failed(offsetInBytes))
    return ResolvedContiguousAddress {base, *offsetInBytes};
  return mlir::failure();
}
|
|
|
|
} // namespace onnx_mlir
|