Files
Raptor/src/PIM/Common/IR/ConstantUtils.cpp
T
NiccoloN 356be6ccc2
Validate Operations / validate-operations (push) Has been cancelled
uniquify constants produced by affine lowering
2026-06-01 10:52:25 +02:00

158 lines
5.4 KiB
C++

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
/// Extracts the integer payload of an `arith.constant` whose result type and
/// value attribute are both `index`. Any other constant yields std::nullopt.
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
  if (constantOp.getType().isIndex()) {
    auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
    if (integerAttr && integerAttr.getType().isIndex())
      return integerAttr.getInt();
  }
  return std::nullopt;
}
/// Picks the block into which constants for `anchorOp` should be materialized.
///
/// Preference order:
///  1. the entry block of `anchorOp` itself, if it is a `func.func` with a body;
///  2. the entry block of the closest enclosing `func.func`;
///  3. the body of `anchorOp` itself, if it is a `builtin.module`;
///  4. the body of the closest enclosing `builtin.module`;
///  5. the block that directly contains `anchorOp`.
///
/// A `func.func` declaration (no body) has no entry block, so it falls through
/// to the module-level choices instead of dereferencing an empty region.
Block* getConstantInsertionBlock(Operation* anchorOp) {
  assert(anchorOp && "expected a valid anchor operation");
  if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp); funcOp && !funcOp.getBody().empty())
    return &funcOp.getBody().front();
  // An enclosing function necessarily has a body, since it contains anchorOp.
  if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
    return &funcOp.getBody().front();
  if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
    return moduleOp.getBody();
  if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
    return moduleOp.getBody();
  return anchorOp->getBlock();
}
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
/// Returns a constant with the same value and type as `constantOp`. The
/// constant is resolved in the insertion block associated with `constantOp`
/// itself, via the folder-based getOrCreateConstant.
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
  Operation* anchor = constantOp.getOperation();
  Attribute payload = constantOp.getValue();
  Type resultType = constantOp.getType();
  return getOrCreateConstant(folder, anchor, payload, resultType);
}
/// Returns an `index`-typed constant equal to `value` near `anchorOp`, using
/// the folder-based getOrCreateConstant.
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
  IndexType indexType = IndexType::get(anchorOp->getContext());
  IntegerAttr indexAttr = IntegerAttr::get(indexType, value);
  return getOrCreateConstant(folder, anchorOp, indexAttr, indexType);
}
/// Returns an `index`-typed constant equal to `value` near `anchorOp`, using
/// the rewriter-based getOrCreateConstant.
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
  IndexType indexType = IndexType::get(anchorOp->getContext());
  IntegerAttr indexAttr = IntegerAttr::get(indexType, value);
  return getOrCreateConstant(rewriter, anchorOp, indexAttr, indexType);
}
/// Hoists every index-typed `arith.constant` of `funcOp` into its entry block
/// and deduplicates them so each distinct index value is defined exactly once.
///
/// For each value, the first such constant located directly in the entry block
/// (in block order) becomes the canonical definition. If the entry block has
/// none, a new constant is created at the start of the entry block, carrying
/// the location of the first constant seen with that value.
///
/// Canonical constants are moved to the start of the entry block, keeping their
/// relative order, so they dominate every use in the function. This includes
/// uses nested inside regions of ops that originally preceded them.
///
/// Every other constant is replaced by its canonical value and erased through
/// `rewriter`, so listeners observe the changes. Constants nested under another
/// IsolatedFromAbove op are left untouched, because values defined in the
/// function body are not visible there. Functions without a body are ignored.
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
  if (funcOp.getBody().empty())
    return;
  Block& entryBlock = funcOp.getBody().front();
  Operation* funcOperation = funcOp.getOperation();

  SmallVector<arith::ConstantOp> constants;
  funcOp.walk([&](arith::ConstantOp constantOp) {
    if (!getIndexConstantValue(constantOp))
      return;
    // Rewiring a constant inside a nested isolated region to a value defined in
    // the function body would produce invalid IR.
    if (constantOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>() != funcOperation)
      return;
    constants.push_back(constantOp);
  });

  // Choose the first entry-block constant of each value as canonical.
  DenseMap<int64_t, Value> canonicalByValue;
  SmallVector<Operation*> canonicalOps;
  for (arith::ConstantOp constantOp : constants) {
    if (constantOp->getBlock() != &entryBlock)
      continue;
    int64_t value = *getIndexConstantValue(constantOp);
    if (canonicalByValue.try_emplace(value, constantOp.getResult()).second)
      canonicalOps.push_back(constantOp.getOperation());
  }

  // Move canonical constants to the top of the entry block (iterating in
  // reverse keeps their original relative order). They have no operands, so
  // this is always legal, and it guarantees they dominate every use. An op
  // must never be moved before itself, hence the front() check.
  for (Operation* canonicalOp : llvm::reverse(canonicalOps))
    if (&entryBlock.front() != canonicalOp)
      rewriter.moveOpBefore(canonicalOp, &entryBlock.front());

  Builder builder(funcOp.getContext());
  for (arith::ConstantOp constantOp : constants) {
    int64_t value = *getIndexConstantValue(constantOp);
    Value canonical = canonicalByValue.lookup(value);
    if (!canonical) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&entryBlock);
      canonical =
          arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(value));
      canonicalByValue[value] = canonical;
    }
    if (constantOp.getResult() == canonical)
      continue;
    // replaceOp rewires all uses and erases the duplicate, notifying listeners.
    rewriter.replaceOp(constantOp, canonical);
  }
}
/// Returns the integer value of `value` when it is an `index`-typed SSA value
/// produced by an `arith.constant` holding an `index` IntegerAttr. Null values,
/// non-index values, and non-constant producers yield std::nullopt.
std::optional<int64_t> matchConstantIndexValue(Value value) {
  if (!value || !value.getType().isIndex())
    return std::nullopt;
  auto producer = value.getDefiningOp<arith::ConstantOp>();
  if (!producer)
    return std::nullopt;
  // arith.constant of index type (which also covers arith::ConstantIndexOp).
  return getIndexConstantValue(producer);
}
/// Returns the constant index carried by `value`, or std::nullopt.
/// - An attribute result matches only when it is an `index`-typed IntegerAttr.
/// - An SSA result is resolved through matchConstantIndexValue(Value).
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value)) {
    auto integerAttr = dyn_cast<IntegerAttr>(attr);
    if (integerAttr && integerAttr.getType().isIndex())
      return integerAttr.getInt();
    return std::nullopt;
  }
  if (auto operand = dyn_cast<Value>(value))
    return matchConstantIndexValue(operand);
  return std::nullopt;
}
} // namespace onnx_mlir