normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-30 16:37:28 +02:00
parent 7c3943bd06
commit ab63498f3f
14 changed files with 340 additions and 278 deletions
+19 -18
View File
@@ -1,10 +1,8 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -75,24 +73,27 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value createAffineApplyOrFoldedConstant(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) {
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
APInt constantValue;
if (!matchPattern(operand, m_ConstantInt(&constantValue)))
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue()));
}
std::optional<int64_t> matchConstantIndexValue(Value value) {
if (!value || !value.getType().isIndex())
return std::nullopt;
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt());
}
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
return std::nullopt;
}
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
if (auto operand = dyn_cast<Value>(value))
return matchConstantIndexValue(operand);
return std::nullopt;
}
} // namespace onnx_mlir