100 lines
3.8 KiB
C++
100 lines
3.8 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
|
|
#include "ConstantUtils.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Selects the block in which constants needed by `anchorOp` are materialized.
///
/// Resolution order:
///  1. Walking from `anchorOp` (inclusive) towards the root, the first
///     spatial/PIM compute op found (`SpatCompute`, `SpatComputeBatch`,
///     `PimCoreOp`, `PimCoreBatchOp`) wins; the block *containing* that op is
///     returned.
///  2. If `anchorOp` is itself a function that has a body, its entry block.
///  3. If `anchorOp` is itself a module, the module body.
///  4. The entry block of the enclosing function, then the enclosing module's
///     body.
///  5. Otherwise the block holding `anchorOp` (null if the op is detached).
Block* getConstantInsertionBlock(Operation* anchorOp) {
  assert(anchorOp && "expected a valid anchor operation");

  for (Operation* current = anchorOp; current; current = current->getParentOp())
    if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
      return current->getBlock();

  // A function declaration (external function) has an empty body region;
  // calling front() on it is undefined behaviour, so only use the body when it
  // actually contains a block. Declarations fall through to the module below.
  if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp); funcOp && !funcOp.getBody().empty())
    return &funcOp.getBody().front();

  if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
    return moduleOp.getBody();

  // `anchorOp` is nested inside this function, so its body is non-empty.
  if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
    return &funcOp.getBody().front();

  if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
    return moduleOp.getBody();

  return anchorOp->getBlock();
}
|
|
|
|
/// Returns a constant of `type` holding `value` for use by `anchorOp`.
///
/// The constant lives in `getConstantInsertionBlock(anchorOp)`. An existing
/// `arith.constant` with the same type and value in that block is reused, but
/// only if it is positioned no later than the op of that block which is (or
/// encloses) `anchorOp`. A constant placed after that point would not dominate
/// the anchor. When nothing reusable is found, creation is delegated to
/// `folder.getOrCreateConstant` with the arith dialect and the host block.
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");

  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  assert(hostBlock && "anchor operation has no block that can host constants");

  // Op of `hostBlock` that is, or transitively contains, the anchor. It is null
  // when the anchor is not nested in `hostBlock` (e.g. the anchor is the
  // function or module owning it); then the whole block is scanned.
  Operation* anchorInBlock = hostBlock->findAncestorOpInBlock(*anchorOp);
  for (Operation& op : *hostBlock) {
    auto constantOp = dyn_cast<arith::ConstantOp>(&op);
    if (constantOp && constantOp.getType() == type && constantOp.getValue() == value)
      return constantOp.getResult();
    // Anything after the anchor's ancestor cannot dominate the anchor.
    if (&op == anchorInBlock)
      break;
  }

  auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
  return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
|
|
|
|
/// Returns an `arith.constant` of `type` holding `value` for use by `anchorOp`,
/// creating it through `rewriter` if necessary.
///
/// The constant lives in `getConstantInsertionBlock(anchorOp)`. An existing
/// `arith.constant` with the same type and value in that block is reused, but
/// only if it is positioned no later than the op of that block which is (or
/// encloses) `anchorOp`, so the reused value dominates the anchor. Otherwise a
/// new `arith.constant` is created at the start of the host block, located at
/// `anchorOp`. The rewriter's insertion point is restored on return. `value`
/// must be a `TypedAttr` (enforced by `cast`).
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
  assert(anchorOp && "expected a valid anchor operation");

  Block* hostBlock = getConstantInsertionBlock(anchorOp);
  assert(hostBlock && "anchor operation has no block that can host constants");

  // Op of `hostBlock` that is, or transitively contains, the anchor. It is null
  // when the anchor is not nested in `hostBlock` (e.g. the anchor is the
  // function or module owning it); then the whole block is scanned.
  Operation* anchorInBlock = hostBlock->findAncestorOpInBlock(*anchorOp);
  for (Operation& op : *hostBlock) {
    auto constantOp = dyn_cast<arith::ConstantOp>(&op);
    if (constantOp && constantOp.getType() == type && constantOp.getValue() == value)
      return constantOp.getResult();
    // Anything after the anchor's ancestor cannot dominate the anchor.
    if (&op == anchorInBlock)
      break;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(hostBlock);
  return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
|
|
|
|
/// Obtains a constant equivalent to `constantOp` (same value and result type),
/// anchored at `constantOp` itself, via the folder-based `getOrCreateConstant`.
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
  Operation* anchor = constantOp.getOperation();
  Attribute payload = constantOp.getValue();
  Type resultType = constantOp.getType();
  return getOrCreateConstant(folder, anchor, payload, resultType);
}
|
|
|
|
/// Returns an `index`-typed constant equal to `value` for use by `anchorOp`,
/// going through the folder-based `getOrCreateConstant`.
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
  MLIRContext* context = anchorOp->getContext();
  Builder attrBuilder(context);
  IntegerAttr indexAttr = attrBuilder.getIndexAttr(value);
  Type indexType = attrBuilder.getIndexType();
  return getOrCreateConstant(folder, anchorOp, indexAttr, indexType);
}
|
|
|
|
/// Returns an `index`-typed constant equal to `value` for use by `anchorOp`,
/// going through the rewriter-based `getOrCreateConstant`.
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
  MLIRContext* context = anchorOp->getContext();
  Builder attrBuilder(context);
  IntegerAttr indexAttr = attrBuilder.getIndexAttr(value);
  Type indexType = attrBuilder.getIndexType();
  return getOrCreateConstant(rewriter, anchorOp, indexAttr, indexType);
}
|
|
|
|
/// Extracts the integer carried by an `index`-typed `value` produced by an
/// `arith.constant`.
///
/// Returns std::nullopt when `value` is null, is not of `index` type, or is not
/// defined by an `arith.constant` whose attribute is an index `IntegerAttr`.
std::optional<int64_t> matchConstantIndexValue(Value value) {
  const bool isIndexValue = value && value.getType().isIndex();
  if (!isIndexValue)
    return std::nullopt;

  if (auto indexConstant = value.getDefiningOp<arith::ConstantIndexOp>())
    return indexConstant.value();

  auto genericConstant = value.getDefiningOp<arith::ConstantOp>();
  if (!genericConstant)
    return std::nullopt;

  auto intAttr = dyn_cast<IntegerAttr>(genericConstant.getValue());
  if (!intAttr || !intAttr.getType().isIndex())
    return std::nullopt;

  return intAttr.getInt();
}
|
|
|
|
/// Extracts a constant index from an `OpFoldResult`.
///
/// A static (attribute) result matches when it is an `IntegerAttr` of `index`
/// type. A dynamic (SSA value) result is delegated to the `Value` overload.
/// Returns std::nullopt for anything else.
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value)) {
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    const bool isIndexInt = intAttr && intAttr.getType().isIndex();
    if (isIndexInt)
      return intAttr.getInt();
  }

  auto operand = dyn_cast<Value>(value);
  if (!operand)
    return std::nullopt;
  return matchConstantIndexValue(operand);
}
|
|
|
|
} // namespace onnx_mlir
|