#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"

#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Picks the block in which constants needed by `anchorOp` are placed.
///
/// Resolution order (first hit wins):
///  1. Walking from `anchorOp` itself outward through its parent ops, the
///     first op accepted by the `isa` check in the loop yields the block that
///     op is placed in (constants land next to that op, not inside it).
///  2. `anchorOp` is itself a function op (the `funcOp` cast; presumably
///     func::FuncOp): its entry block.
///  3. `anchorOp` is itself a module op: the module body block.
///  4. The entry block of the nearest enclosing function op.
///  5. The body block of the nearest enclosing module op.
///  6. Fallback: the block directly containing `anchorOp`.
///
/// Steps 2-3 exist because getParentOfType only searches strict ancestors and
/// would never return `anchorOp` itself.
///
/// NOTE(review): Operation::getBlock() returns null for an op that is not
/// inserted into a block (step 1 hitting a top-level op, or step 6 on a
/// detached op). The getOrCreateConstant overloads dereference the result
/// unconditionally, so confirm anchors are always attached.
Block* getConstantInsertionBlock(Operation* anchorOp) {
  assert(anchorOp && "expected a valid anchor operation");
  // Step 1: innermost op (anchorOp included) accepted by the isa check.
  for (Operation* current = anchorOp; current; current = current->getParentOp())
    if (isa(current))
      return current->getBlock();
  // Steps 2-3: anchorOp is itself the function / module.
  if (auto funcOp = dyn_cast(anchorOp))
    return &funcOp.getBody().front();
  if (auto moduleOp = dyn_cast(anchorOp))
    return moduleOp.getBody();
  // Steps 4-5: nearest enclosing function / module.
  if (auto funcOp = anchorOp->getParentOfType())
    return &funcOp.getBody().front();
  if (auto moduleOp = anchorOp->getParentOfType())
    return moduleOp.getBody();
  // Step 6: fallback (null when anchorOp is not in any block).
  return anchorOp->getBlock();
}

/// Returns a constant of `type` holding `value` for use by `anchorOp`.
///
/// The host block comes from getConstantInsertionBlock(anchorOp).
///  - If the host block already contains a constant op with exactly the same
///    result type and value attribute, the first such op (in block order) is
///    reused.
///  - Otherwise the request is forwarded to `folder`, together with the host
///    block and the arith dialect, and the folder materializes the constant.
///
/// NOTE(review): per the OperationFolder documentation, the folder may create
/// the constant in a parent block (the entry block of an enclosing insertion
/// region) rather than in `hostBlock` itself. Later calls will therefore find
/// it through the folder's own cache, not through the scan below.
///
/// NOTE(review): the scan costs time linear in the host block's size on every
/// call. It also does not check that a reused constant precedes the use site.
/// A match placed later in the host block than the op that will use it would
/// not dominate that user. Confirm constants in host blocks are always kept at
/// the block start.
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");
  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  // Reuse an existing identical constant already living in the host block.
  for (Operation& op : *hostBlock) {
    auto constantOp = dyn_cast(&op);
    if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
      continue;
    return constantOp.getResult();
  }
  // No match: make sure the arith dialect is loaded in the context, then let
  // the folder materialize the constant through it.
  auto* arithDialect = anchorOp->getContext()->getOrLoadDialect();
  return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}

/// Rewriter-based counterpart of the OperationFolder overload above.
///
/// Reuses the first constant op in the host block with matching result type
/// and value attribute. Otherwise it creates a new arith::ConstantOp:
///  - at the start of the host block,
///  - carrying `anchorOp`'s location,
///  - through `rewriter`.
/// The InsertionGuard restores the rewriter's previous insertion point on
/// return.
///
/// `value` is `cast` to the attribute kind the constant builder expects. That
/// cast asserts if `value` is not of that kind.
///
/// NOTE(review): the same linear scan and use-site dominance caveats as the
/// OperationFolder overload apply to the reuse scan here.
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");
  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  // Reuse an existing identical constant already living in the host block.
  for (Operation& op : *hostBlock) {
    auto constantOp = dyn_cast(&op);
    if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
      continue;
    return constantOp.getResult();
  }
  // No match: create it at the top of the host block without disturbing the
  // caller's insertion point.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(hostBlock);
  return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast(value)).getResult();
}

/// Returns a constant equivalent to `constantOp` (same value attribute and
/// result type), using `constantOp` itself as the anchor.
///
/// If a matching constant already lives in the host block, the first one in
/// block order is returned. That match may be `constantOp` itself.
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
  return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
}

/// Convenience wrapper: an `index`-typed constant equal to `value`, obtained
/// via the OperationFolder overload of getOrCreateConstant.
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
  Builder builder(anchorOp->getContext());
  return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}

/// Convenience wrapper: an `index`-typed constant equal to `value`, obtained
/// via the RewriterBase overload of getOrCreateConstant.
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
  Builder builder(anchorOp->getContext());
  return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}

/// Extracts the integer held by an `index`-typed constant SSA value.
///
/// Returns std::nullopt when `value` is null, is not of `index` type, or is
/// not produced by one of the two recognized defining ops. Those ops are
/// presumably arith::ConstantIndexOp and arith::ConstantOp whose attribute is
/// an index-typed integer attribute.
std::optional matchConstantIndexValue(Value value) {
  if (!value || !value.getType().isIndex())
    return std::nullopt;
  if (auto constant = value.getDefiningOp())
    return constant.value();
  if (auto constant = value.getDefiningOp())
    if (auto intAttr = dyn_cast(constant.getValue()); intAttr && intAttr.getType().isIndex())
      return intAttr.getInt();
  return std::nullopt;
}

/// OpFoldResult overload.
///
/// Accepts either:
///  - an attribute that is an integer attribute of `index` type, or
///  - a Value, which is matched by the Value overload above.
/// Integer attributes of any non-index type (e.g. i64) yield std::nullopt.
std::optional matchConstantIndexValue(OpFoldResult value) {
  if (auto attr = dyn_cast(value))
    if (auto intAttr = dyn_cast(attr); intAttr && intAttr.getType().isIndex())
      return intAttr.getInt();
  if (auto operand = dyn_cast(value))
    return matchConstantIndexValue(operand);
  return std::nullopt;
}

} // namespace onnx_mlir