uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -10,6 +10,17 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
|
||||
if (!constantOp.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
|
||||
if (!intAttr || !intAttr.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
return intAttr.getInt();
|
||||
}
|
||||
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
@@ -67,6 +78,59 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
|
||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
|
||||
if (funcOp.getBody().empty())
|
||||
return;
|
||||
|
||||
Block& entryBlock = funcOp.getBody().front();
|
||||
DenseMap<int64_t, Value> canonicalByValue;
|
||||
SmallVector<arith::ConstantOp> constants;
|
||||
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (!getIndexConstantValue(constantOp))
|
||||
return;
|
||||
constants.push_back(constantOp);
|
||||
});
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value || constantOp->getBlock() != &entryBlock)
|
||||
continue;
|
||||
canonicalByValue.try_emplace(*value, constantOp.getResult());
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
|
||||
Value canonical = canonicalByValue.lookup(*value);
|
||||
if (!canonical) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&entryBlock);
|
||||
Builder builder(funcOp.getContext());
|
||||
canonical =
|
||||
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
|
||||
canonicalByValue[*value] = canonical;
|
||||
}
|
||||
|
||||
if (constantOp.getResult() == canonical)
|
||||
continue;
|
||||
|
||||
constantOp.getResult().replaceAllUsesWith(canonical);
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
if (constantOp.getResult() == canonicalByValue.lookup(*value))
|
||||
continue;
|
||||
if (constantOp.use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
||||
if (!value || !value.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
Reference in New Issue
Block a user