This commit is contained in:
@@ -40,6 +40,21 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(hostBlock);
|
||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||
}
|
||||
@@ -49,6 +64,11 @@ Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, Operation
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
@@ -17,10 +14,17 @@ mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Type type,
|
||||
mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::RewriterBase& rewriter);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
Reference in New Issue
Block a user