uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -10,6 +10,17 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
|
||||||
|
if (!constantOp.getType().isIndex())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
|
||||||
|
if (!intAttr || !intAttr.getType().isIndex())
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
return intAttr.getInt();
|
||||||
|
}
|
||||||
|
|
||||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||||
assert(anchorOp && "expected a valid anchor operation");
|
assert(anchorOp && "expected a valid anchor operation");
|
||||||
|
|
||||||
@@ -67,6 +78,59 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6
|
|||||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
|
||||||
|
if (funcOp.getBody().empty())
|
||||||
|
return;
|
||||||
|
|
||||||
|
Block& entryBlock = funcOp.getBody().front();
|
||||||
|
DenseMap<int64_t, Value> canonicalByValue;
|
||||||
|
SmallVector<arith::ConstantOp> constants;
|
||||||
|
|
||||||
|
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||||
|
if (!getIndexConstantValue(constantOp))
|
||||||
|
return;
|
||||||
|
constants.push_back(constantOp);
|
||||||
|
});
|
||||||
|
|
||||||
|
for (arith::ConstantOp constantOp : constants) {
|
||||||
|
auto value = getIndexConstantValue(constantOp);
|
||||||
|
if (!value || constantOp->getBlock() != &entryBlock)
|
||||||
|
continue;
|
||||||
|
canonicalByValue.try_emplace(*value, constantOp.getResult());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (arith::ConstantOp constantOp : constants) {
|
||||||
|
auto value = getIndexConstantValue(constantOp);
|
||||||
|
if (!value)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Value canonical = canonicalByValue.lookup(*value);
|
||||||
|
if (!canonical) {
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(&entryBlock);
|
||||||
|
Builder builder(funcOp.getContext());
|
||||||
|
canonical =
|
||||||
|
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
|
||||||
|
canonicalByValue[*value] = canonical;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (constantOp.getResult() == canonical)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
constantOp.getResult().replaceAllUsesWith(canonical);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
|
||||||
|
auto value = getIndexConstantValue(constantOp);
|
||||||
|
if (!value)
|
||||||
|
continue;
|
||||||
|
if (constantOp.getResult() == canonicalByValue.lookup(*value))
|
||||||
|
continue;
|
||||||
|
if (constantOp.use_empty())
|
||||||
|
rewriter.eraseOp(constantOp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
||||||
if (!value || !value.getType().isIndex())
|
if (!value || !value.getType().isIndex())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
@@ -23,6 +24,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat
|
|||||||
|
|
||||||
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||||
|
|
||||||
|
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
|
||||||
|
|
||||||
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
|
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
|
||||||
|
|
||||||
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
|
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
|
||||||
|
|||||||
@@ -26,6 +26,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
|
#include "Common/IR/ConstantUtils.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/Patterns.hpp"
|
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||||
@@ -264,6 +265,8 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
hoistAndUniquifyIndexConstants(funcOp, rewriter);
|
||||||
|
|
||||||
// Dump to file for debug
|
// Dump to file for debug
|
||||||
dumpModule(moduleOp, "pim0");
|
dumpModule(moduleOp, "pim0");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user