uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -10,6 +10,17 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
|
||||
if (!constantOp.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
|
||||
if (!intAttr || !intAttr.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
return intAttr.getInt();
|
||||
}
|
||||
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
@@ -67,6 +78,59 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
|
||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
|
||||
if (funcOp.getBody().empty())
|
||||
return;
|
||||
|
||||
Block& entryBlock = funcOp.getBody().front();
|
||||
DenseMap<int64_t, Value> canonicalByValue;
|
||||
SmallVector<arith::ConstantOp> constants;
|
||||
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (!getIndexConstantValue(constantOp))
|
||||
return;
|
||||
constants.push_back(constantOp);
|
||||
});
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value || constantOp->getBlock() != &entryBlock)
|
||||
continue;
|
||||
canonicalByValue.try_emplace(*value, constantOp.getResult());
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
|
||||
Value canonical = canonicalByValue.lookup(*value);
|
||||
if (!canonical) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&entryBlock);
|
||||
Builder builder(funcOp.getContext());
|
||||
canonical =
|
||||
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
|
||||
canonicalByValue[*value] = canonical;
|
||||
}
|
||||
|
||||
if (constantOp.getResult() == canonical)
|
||||
continue;
|
||||
|
||||
constantOp.getResult().replaceAllUsesWith(canonical);
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
if (constantOp.getResult() == canonicalByValue.lookup(*value))
|
||||
continue;
|
||||
if (constantOp.use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
||||
if (!value || !value.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
@@ -23,6 +24,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat
|
||||
|
||||
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/IR/ConstantUtils.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||
@@ -264,6 +265,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
hoistAndUniquifyIndexConstants(funcOp, rewriter);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user