#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"

#include "ConstantUtils.hpp"

#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

using namespace mlir;

namespace onnx_mlir {

/// Returns the integer held by `constantOp` when it is an index-typed
/// `arith.constant` whose attribute is an index-typed IntegerAttr; returns
/// std::nullopt for every other constant.
static std::optional<int64_t> getIndexConstantValue(
    arith::ConstantOp constantOp) {
  if (!constantOp.getType().isIndex())
    return std::nullopt;
  auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
  if (!intAttr || !intAttr.getType().isIndex())
    return std::nullopt;
  return intAttr.getInt();
}

/// Returns the first `arith.constant` directly contained in `block` whose
/// result type is `type` and whose value attribute equals `value`, or a null
/// op when no such constant exists.
static arith::ConstantOp findMatchingConstant(
    Block* block, Attribute value, Type type) {
  for (Operation& op : *block) {
    auto constantOp = dyn_cast<arith::ConstantOp>(&op);
    if (constantOp && constantOp.getType() == type &&
        constantOp.getValue() == value)
      return constantOp;
  }
  return {};
}

/// Returns the entry block of `funcOp`, or nullptr when `funcOp` is null or
/// is a declaration (empty body region), for which `front()` would be invalid.
static Block* getFuncEntryBlock(func::FuncOp funcOp) {
  if (!funcOp || funcOp.getBody().empty())
    return nullptr;
  return &funcOp.getBody().front();
}

/// Chooses the block in which constants requested for `anchorOp` are placed.
///
/// Preference order:
///   1. `anchorOp` itself, if it is a `func.func` with a body (entry block);
///   2. the nearest enclosing `func.func` with a body (entry block);
///   3. `anchorOp` itself, if it is a `builtin.module` (its body);
///   4. the nearest enclosing `builtin.module` (its body);
///   5. the block containing `anchorOp` (null for a detached op).
Block* getConstantInsertionBlock(Operation* anchorOp) {
  assert(anchorOp && "expected a valid anchor operation");
  if (Block* entry = getFuncEntryBlock(dyn_cast<func::FuncOp>(anchorOp)))
    return entry;
  if (Block* entry =
          getFuncEntryBlock(anchorOp->getParentOfType<func::FuncOp>()))
    return entry;
  if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
    return moduleOp.getBody();
  if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
    return moduleOp.getBody();
  return anchorOp->getBlock();
}

/// Returns an `arith.constant` of `type` holding `value` in the insertion
/// block chosen for `anchorOp`, reusing an existing identical constant in that
/// block when there is one and otherwise materializing a new one through
/// `folder`.
///
/// A reused constant is moved to the start of the host block: it may have been
/// defined after `anchorOp` (or after the op enclosing it), in which case it
/// would not dominate uses at the anchor. `arith.constant` takes no operands,
/// so moving it earlier within its own block is always legal.
///
/// The returned Value is whatever `OperationFolder::getOrCreateConstant`
/// produces when no existing constant is found (it may be null if the arith
/// dialect cannot materialize `value`).
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp,
    Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");
  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  assert(hostBlock && "anchor operation is not nested in any block");

  if (arith::ConstantOp existing =
          findMatchingConstant(hostBlock, value, type)) {
    if (existing.getOperation() != &hostBlock->front())
      existing->moveBefore(&hostBlock->front());
    return existing.getResult();
  }

  auto* arithDialect =
      anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
  return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}

/// Rewriter-based counterpart of the OperationFolder overload: reuses an
/// identical `arith.constant` in the host block (moving it to the block start
/// via `rewriter` so listeners are notified and dominance holds), or creates a
/// new one at the start of the host block. `value` must be a TypedAttr.
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp,
    Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");
  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  assert(hostBlock && "anchor operation is not nested in any block");

  if (arith::ConstantOp existing =
          findMatchingConstant(hostBlock, value, type)) {
    if (existing.getOperation() != &hostBlock->front())
      rewriter.moveOpBefore(existing, hostBlock, hostBlock->begin());
    return existing.getResult();
  }

  // Leave the caller's insertion point untouched.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(hostBlock);
  return arith::ConstantOp::create(
      rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value))
      .getResult();
}

/// Returns a shared constant equivalent to `constantOp` (same value and
/// type), anchored at `constantOp` itself.
Value getOrCreateConstantLike(
    OperationFolder& folder, arith::ConstantOp constantOp) {
  return getOrCreateConstant(folder, constantOp.getOperation(),
      constantOp.getValue(), constantOp.getType());
}

/// Returns a shared index-typed constant holding `value` for `anchorOp`.
Value getOrCreateIndexConstant(
    OperationFolder& folder, Operation* anchorOp, int64_t value) {
  Builder builder(anchorOp->getContext());
  return getOrCreateConstant(
      folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}

/// Returns a shared index-typed constant holding `value` for `anchorOp`.
Value getOrCreateIndexConstant(
    RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
  // RewriterBase is itself a Builder; no separate Builder is needed.
  return getOrCreateConstant(rewriter, anchorOp, rewriter.getIndexAttr(value),
      rewriter.getIndexType());
}

/// Ensures every index-typed `arith.constant` in `funcOp` refers to a single
/// canonical constant per value, located at the start of the entry block.
///
/// The canonical constant for a value is the first such constant already in
/// the entry block (moved to the block start so it dominates uses nested in
/// ops that precede it); if none exists, a new one is created at the block
/// start. All other constants with that value are replaced by the canonical
/// one and erased through `rewriter`. Does nothing for a body-less function.
void hoistAndUniquifyIndexConstants(
    func::FuncOp funcOp, RewriterBase& rewriter) {
  if (funcOp.getBody().empty())
    return;
  Block& entryBlock = funcOp.getBody().front();

  // Snapshot all index constants (with their values) before mutating the IR.
  SmallVector<std::pair<arith::ConstantOp, int64_t>> constants;
  funcOp.walk([&](arith::ConstantOp constantOp) {
    if (std::optional<int64_t> value = getIndexConstantValue(constantOp))
      constants.emplace_back(constantOp, *value);
  });

  // Prefer constants already living in the entry block as canonical ones.
  DenseMap<int64_t, Value> canonicalByValue;
  for (const auto& [constantOp, value] : constants) {
    if (constantOp->getBlock() != &entryBlock)
      continue;
    if (!canonicalByValue.try_emplace(value, constantOp.getResult()).second)
      continue;
    // Without this move, a duplicate nested inside an op that precedes the
    // canonical constant would be rewired to a value it is not dominated by.
    if (constantOp.getOperation() != &entryBlock.front())
      rewriter.moveOpBefore(constantOp, &entryBlock, entryBlock.begin());
  }

  for (const auto& [constantOp, value] : constants) {
    Value canonical = canonicalByValue.lookup(value);
    if (!canonical) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&entryBlock);
      canonical = arith::ConstantOp::create(rewriter, constantOp.getLoc(),
          rewriter.getIndexType(), rewriter.getIndexAttr(value));
      canonicalByValue[value] = canonical;
    }
    if (constantOp.getResult() == canonical)
      continue;
    // replaceOp = RAUW + erase, with rewriter listeners notified of both.
    rewriter.replaceOp(constantOp, canonical);
  }
}

/// Returns the integer value of `value` when it is produced by an index-typed
/// `arith.constant` (this includes `arith::ConstantIndexOp`); std::nullopt
/// for null values, non-index values, and non-constant producers.
std::optional<int64_t> matchConstantIndexValue(Value value) {
  if (!value || !value.getType().isIndex())
    return std::nullopt;
  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
    return getIndexConstantValue(constantOp);
  return std::nullopt;
}

/// Returns the integer value of `value` when it is an index-typed
/// IntegerAttr, or an SSA value accepted by the Value overload above;
/// std::nullopt otherwise (e.g. integer attributes of non-index type).
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value))
    if (auto intAttr = dyn_cast<IntegerAttr>(attr);
        intAttr && intAttr.getType().isIndex())
      return intAttr.getInt();
  if (auto operand = dyn_cast<Value>(value))
    return matchConstantIndexValue(operand);
  return std::nullopt;
}

} // namespace onnx_mlir