automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-27 16:39:56 +02:00
parent 4bdaa57656
commit 874a2f53e6
23 changed files with 136 additions and 198 deletions
+25 -43
View File
@@ -111,39 +111,29 @@ static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp l
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq:
return lhs == rhs;
case mlir::arith::CmpIPredicate::ne:
return lhs != rhs;
case mlir::arith::CmpIPredicate::slt:
return lhs < rhs;
case mlir::arith::CmpIPredicate::sle:
return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt:
return lhs > rhs;
case mlir::arith::CmpIPredicate::sge:
return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult:
return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule:
return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt:
return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge:
return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr, const StaticValueKnowledge& knowledge) {
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant:
return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
@@ -158,19 +148,16 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add:
return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub:
return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul:
return *lhs * *rhs;
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
@@ -191,10 +178,8 @@ llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI:
return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default:
llvm_unreachable("unexpected binary compiled index kind");
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default: llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
@@ -639,24 +624,21 @@ llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Valu
staticStrides.reserve(subviewOp.getMixedStrides().size());
bool allStatic = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
for (mlir::OpFoldResult size : subviewOp.getMixedSizes())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(size))
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(stride))
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
allStatic = false;
}
if (allStatic) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
@@ -796,8 +778,8 @@ llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge&
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress>
CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const {
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))