uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-01 10:52:25 +02:00
parent b678e55d3c
commit 356be6ccc2
3 changed files with 70 additions and 0 deletions
+64
View File
@@ -10,6 +10,17 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
if (!constantOp.getType().isIndex())
return std::nullopt;
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr || !intAttr.getType().isIndex())
return std::nullopt;
return intAttr.getInt();
}
Block* getConstantInsertionBlock(Operation* anchorOp) { Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation"); assert(anchorOp && "expected a valid anchor operation");
@@ -67,6 +78,59 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
} }
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
if (funcOp.getBody().empty())
return;
Block& entryBlock = funcOp.getBody().front();
DenseMap<int64_t, Value> canonicalByValue;
SmallVector<arith::ConstantOp> constants;
funcOp.walk([&](arith::ConstantOp constantOp) {
if (!getIndexConstantValue(constantOp))
return;
constants.push_back(constantOp);
});
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value || constantOp->getBlock() != &entryBlock)
continue;
canonicalByValue.try_emplace(*value, constantOp.getResult());
}
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
Value canonical = canonicalByValue.lookup(*value);
if (!canonical) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&entryBlock);
Builder builder(funcOp.getContext());
canonical =
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
canonicalByValue[*value] = canonical;
}
if (constantOp.getResult() == canonical)
continue;
constantOp.getResult().replaceAllUsesWith(canonical);
}
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
if (constantOp.getResult() == canonicalByValue.lookup(*value))
continue;
if (constantOp.use_empty())
rewriter.eraseOp(constantOp);
}
}
std::optional<int64_t> matchConstantIndexValue(Value value) { std::optional<int64_t> matchConstantIndexValue(Value value) {
if (!value || !value.getType().isIndex()) if (!value || !value.getType().isIndex())
return std::nullopt; return std::nullopt;
+3
View File
@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
@@ -23,6 +24,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
std::optional<int64_t> matchConstantIndexValue(mlir::Value value); std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value); std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
@@ -26,6 +26,7 @@
#include <utility> #include <utility>
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Common/IR/ConstantUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp" #include "Conversion/SpatialToPim/Patterns.hpp"
@@ -264,6 +265,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
return; return;
} }
hoistAndUniquifyIndexConstants(funcOp, rewriter);
// Dump to file for debug // Dump to file for debug
dumpModule(moduleOp, "pim0"); dumpModule(moduleOp, "pim0");
} }