diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index fe21409..47fd66a 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -10,6 +10,17 @@ using namespace mlir; namespace onnx_mlir { +static std::optional getIndexConstantValue(arith::ConstantOp constantOp) { + if (!constantOp.getType().isIndex()) + return std::nullopt; + + auto intAttr = dyn_cast(constantOp.getValue()); + if (!intAttr || !intAttr.getType().isIndex()) + return std::nullopt; + + return intAttr.getInt(); +} + Block* getConstantInsertionBlock(Operation* anchorOp) { assert(anchorOp && "expected a valid anchor operation"); @@ -67,6 +78,59 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6 return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); } +void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) { + if (funcOp.getBody().empty()) + return; + + Block& entryBlock = funcOp.getBody().front(); + DenseMap canonicalByValue; + SmallVector constants; + + funcOp.walk([&](arith::ConstantOp constantOp) { + if (!getIndexConstantValue(constantOp)) + return; + constants.push_back(constantOp); + }); + + for (arith::ConstantOp constantOp : constants) { + auto value = getIndexConstantValue(constantOp); + if (!value || constantOp->getBlock() != &entryBlock) + continue; + canonicalByValue.try_emplace(*value, constantOp.getResult()); + } + + for (arith::ConstantOp constantOp : constants) { + auto value = getIndexConstantValue(constantOp); + if (!value) + continue; + + Value canonical = canonicalByValue.lookup(*value); + if (!canonical) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&entryBlock); + Builder builder(funcOp.getContext()); + canonical = + arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value)); + canonicalByValue[*value] = canonical; + } + + if (constantOp.getResult() == canonical) + continue; + + constantOp.getResult().replaceAllUsesWith(canonical); + } + + for (arith::ConstantOp constantOp : llvm::reverse(constants)) { + auto value = getIndexConstantValue(constantOp); + if (!value) + continue; + if (constantOp.getResult() == canonicalByValue.lookup(*value)) + continue; + if (constantOp.use_empty()) + rewriter.eraseOp(constantOp); + } +} + std::optional matchConstantIndexValue(Value value) { if (!value || !value.getType().isIndex()) return std::nullopt; diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index a0a96f5..fc8be72 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/FoldUtils.h" @@ -23,6 +24,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); +void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter); + std::optional matchConstantIndexValue(mlir::Value value); std::optional matchConstantIndexValue(mlir::OpFoldResult value); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index b3a85aa..957e522 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -26,6 +26,7 @@ #include #include "Common/PimCommon.hpp" +#include "Common/IR/ConstantUtils.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Patterns.hpp" @@ -264,6 +265,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { signalPassFailure(); return; } + hoistAndUniquifyIndexConstants(funcOp, rewriter); + // Dump to file for debug dumpModule(moduleOp, "pim0"); }