259 lines
9.7 KiB
C++
259 lines
9.7 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Finds the `memref::GlobalOp` symbol referenced by `getGlobalOp` in `moduleOp`.
///
/// \param moduleOp    Module whose symbol table is searched.
/// \param getGlobalOp Operation naming the global to look up.
/// \return A null op handle when either argument is null; otherwise the result
///         of `lookupSymbol` for the referenced name (which may itself be null).
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
  const bool bothPresent = moduleOp && getGlobalOp;
  if (bothPresent)
    return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
  return {};
}
|
|
|
|
namespace {
|
|
|
|
/// Follows the `knowledge->aliases` chain starting at `value` and returns the
/// last value reached.
///
/// \param value     Value to resolve.
/// \param knowledge Optional alias table; when null, `value` is returned as is.
/// \return The end of the alias chain. For an acyclic table this is the first
///         value that is not itself a key of `aliases`.
///
/// The number of hops is capped at `aliases.size()`: an acyclic chain can never
/// be longer than the number of entries, so the cap only triggers when the
/// table contains a cycle (including a value aliased to itself). In that case
/// the walk stops instead of looping forever and returns the value reached.
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;

  const auto& aliases = knowledge->aliases;
  auto remainingHops = aliases.size();

  auto iter = aliases.find(value);
  while (iter != aliases.end() && remainingHops > 0) {
    value = iter->second;
    iter = aliases.find(value);
    --remainingHops;
  }
  return value;
}
|
|
|
|
/// Walks backwards from `value` through operations that merely forward a
/// buffer, and returns the value at which the walk stops.
///
/// At each step the current value is first passed through `resolveAlias`.
/// The walk then continues through:
///   - a result of a `DestinationStyleOpInterface` op, to its tied operand;
///   - `memref::CastOp`, `memref::CollapseShapeOp`, `memref::ExpandShapeOp`,
///     to their source.
/// It stops at block arguments and at any other defining operation.
///
/// \param value     Starting value.
/// \param knowledge Optional alias table forwarded to `resolveAlias`.
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  for (;;) {
    value = resolveAlias(value, knowledge);

    if (mlir::isa<mlir::BlockArgument>(value))
      return value;

    mlir::Operation* producer = value.getDefiningOp();
    if (!producer)
      return value;

    // Destination-style ops: the result shares storage with its tied operand.
    if (auto dpsOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(producer)) {
      if (auto opResult = mlir::dyn_cast<mlir::OpResult>(value)) {
        if (mlir::OpOperand* tied = dpsOp.getTiedOpOperand(opResult)) {
          value = tied->get();
          continue;
        }
      }
    }

    // View-like memref ops: keep walking from their source.
    if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(producer)) {
      value = castOp.getSource();
      continue;
    }
    if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(producer)) {
      value = collapseOp.getSrc();
      continue;
    }
    if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(producer)) {
      value = expandOp.getSrc();
      continue;
    }

    return value;
  }
}
|
|
|
|
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
|
|
|
/// Tries to evaluate `value` to a compile-time integer.
///
/// Resolution order:
///   1. follow `knowledge->aliases` via `resolveAlias`;
///   2. return the matching entry of `knowledge->indexValues`, if any;
///   3. read an `arith::ConstantOp` whose value is an `IntegerAttr`;
///   4. recurse through `arith::IndexCastOp`, `AddIOp`, `SubIOp`, `MulIOp`,
///      `DivUIOp` and `RemUIOp`.
///
/// \param value     Value to evaluate.
/// \param knowledge Optional table of known aliases / index values (may be null).
/// \return The folded integer, or failure when `value` is a block argument with
///         no known value, is produced by an unsupported operation, has an
///         operand that cannot be resolved, or divides by zero.
///
/// Add/sub/mul are computed in `uint64_t` so that overflow wraps instead of
/// triggering signed-overflow undefined behaviour.
/// NOTE(review): this wraps at 64 bits, which matches the arith ops only for
/// 64-bit operands (e.g. a 64-bit `index`) — confirm for the target.
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);

  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }

  if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
    if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }

  mlir::Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return mlir::failure();

  if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);

  // Operands of the binary op being folded, reinterpreted as unsigned so that
  // the arithmetic below is well defined for every bit pattern.
  uint64_t lhs = 0;
  uint64_t rhs = 0;
  auto resolveOperands = [&](auto binaryOp) {
    auto resolvedLhs = resolveIndexValueImpl(binaryOp.getLhs(), knowledge);
    auto resolvedRhs = resolveIndexValueImpl(binaryOp.getRhs(), knowledge);
    if (failed(resolvedLhs) || failed(resolvedRhs))
      return false;
    lhs = static_cast<uint64_t>(*resolvedLhs);
    rhs = static_cast<uint64_t>(*resolvedRhs);
    return true;
  };

  if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
    if (!resolveOperands(addOp))
      return mlir::failure();
    return static_cast<int64_t>(lhs + rhs);
  }

  if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
    if (!resolveOperands(subOp))
      return mlir::failure();
    return static_cast<int64_t>(lhs - rhs);
  }

  if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
    if (!resolveOperands(mulOp))
      return mlir::failure();
    return static_cast<int64_t>(lhs * rhs);
  }

  if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
    // Division by zero cannot be folded to a value.
    if (!resolveOperands(divOp) || rhs == 0)
      return mlir::failure();
    return static_cast<int64_t>(lhs / rhs);
  }

  if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
    if (!resolveOperands(remOp) || rhs == 0)
      return mlir::failure();
    return static_cast<int64_t>(lhs % rhs);
  }

  return mlir::failure();
}
|
|
|
|
/// Evaluates an `OpFoldResult` to a constant integer.
///
/// \param ofr       Either an attribute or an SSA value.
/// \param knowledge Optional static knowledge forwarded to the value resolver.
/// \return The integer held by an `IntegerAttr`, the result of
///         `resolveIndexValueImpl` for an SSA value, or failure when the
///         attribute is not an `IntegerAttr`.
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  auto attr = mlir::dyn_cast<mlir::Attribute>(ofr);
  if (!attr)
    return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);

  if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr))
    return integerAttr.getInt();
  return mlir::failure();
}
|
|
|
|
/// Traces `value` back to the buffer it is a contiguous view of, accumulating
/// the byte offset introduced by `memref::SubViewOp`s along the way.
///
/// The walk succeeds when it reaches a block argument, a `memref::AllocOp` or a
/// `memref::GetGlobalOp`; the result pairs that base value with the accumulated
/// byte offset. Along the way it steps through:
///   - `DestinationStyleOpInterface` results (to the tied operand);
///   - `scf::ForOp` results (to the yielded value, or to the matching init
///     argument when the loop yields its own iteration argument);
///   - `memref::SubViewOp` (only with static shapes, fully resolvable
///     offsets/sizes/strides, and a region accepted by `isMemoryContiguous`);
///   - `memref::CastOp`, `memref::CollapseShapeOp`, `memref::ExpandShapeOp`.
/// Any other producer yields failure.
///
/// \param value     Value whose base address is requested.
/// \param knowledge Optional static knowledge used for aliases and index values.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
                                                                        const StaticValueKnowledge* knowledge) {
  int64_t accumulatedBytes = 0;
  mlir::Value current = resolveAlias(value, knowledge);

  // Resolves every entry of a mixed static/dynamic list into `resolved`;
  // returns false as soon as one entry cannot be evaluated.
  auto resolveAll = [knowledge](const auto& mixedValues, llvm::SmallVector<int64_t>& resolved) {
    resolved.reserve(mixedValues.size());
    for (mlir::OpFoldResult entry : mixedValues) {
      auto constant = resolveOpFoldResult(entry, knowledge);
      if (failed(constant))
        return false;
      resolved.push_back(*constant);
    }
    return true;
  };

  for (;;) {
    if (mlir::isa<mlir::BlockArgument>(current))
      return ResolvedContiguousAddress {current, accumulatedBytes};

    mlir::Operation* producer = current.getDefiningOp();
    if (!producer)
      return mlir::failure();

    // Destination-style op: continue from the operand tied to this result.
    if (auto dpsOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(producer)) {
      mlir::OpOperand* tied = dpsOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(current));
      if (!tied)
        return mlir::failure();
      current = resolveAlias(tied->get(), knowledge);
      continue;
    }

    // scf.for result: look at what the body yields for this result.
    if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(producer)) {
      auto loopResult = mlir::dyn_cast<mlir::OpResult>(current);
      if (!loopResult)
        return mlir::failure();

      auto terminator = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
      mlir::Value yielded =
          resolveLoopCarriedAliasImpl(terminator.getOperand(loopResult.getResultNumber()), knowledge);

      // Body argument 0 is skipped; argument k > 0 is mapped to init argument k - 1.
      auto carried = mlir::dyn_cast<mlir::BlockArgument>(yielded);
      const bool yieldsOwnIterArg = carried && carried.getOwner() == forOp.getBody() && carried.getArgNumber() > 0
                                    && static_cast<unsigned>(carried.getArgNumber() - 1) < forOp.getInitArgs().size();

      current = yieldsOwnIterArg ? resolveAlias(forOp.getInitArgs()[carried.getArgNumber() - 1], knowledge) : yielded;
      continue;
    }

    // Subview: add its linearized offset (in bytes) and continue from the source.
    if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(producer)) {
      auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
      auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
      const bool bothStatic =
          sourceType && subviewType && sourceType.hasStaticShape() && subviewType.hasStaticShape();
      if (!bothStatic)
        return mlir::failure();

      llvm::SmallVector<int64_t> offsets;
      llvm::SmallVector<int64_t> sizes;
      llvm::SmallVector<int64_t> strides;
      if (!resolveAll(subviewOp.getMixedOffsets(), offsets) || !resolveAll(subviewOp.getMixedSizes(), sizes)
          || !resolveAll(subviewOp.getMixedStrides(), strides))
        return mlir::failure();

      if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
        return mlir::failure();

      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      accumulatedBytes += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
      current = resolveAlias(subviewOp.getSource(), knowledge);
      continue;
    }

    // Cast / reshape ops contribute no offset; continue from their source.
    if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(producer)) {
      current = resolveAlias(castOp.getSource(), knowledge);
      continue;
    }
    if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(producer)) {
      current = resolveAlias(collapseOp.getSrc(), knowledge);
      continue;
    }
    if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(producer)) {
      current = resolveAlias(expandOp.getSrc(), knowledge);
      continue;
    }

    // Allocation or global reference: this is the base buffer.
    if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(producer))
      return ResolvedContiguousAddress {current, accumulatedBytes};

    return mlir::failure();
  }
}
|
|
|
|
} // namespace
|
|
|
|
/// Evaluates `value` to a constant integer using only the IR itself (no
/// caller-supplied static knowledge). Returns failure when it cannot be folded.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) {
  const StaticValueKnowledge* noKnowledge = nullptr;
  return resolveIndexValueImpl(value, noKnowledge);
}
|
|
|
|
/// Evaluates `value` to a constant integer, consulting `knowledge` for known
/// aliases and index values. Returns failure when it cannot be folded.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveIndexValueImpl(value, knowledgePtr);
}
|
|
|
|
/// Resolves `value` to its base buffer and byte offset using only the IR itself
/// (no caller-supplied static knowledge). Returns failure when the address
/// cannot be resolved.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
  const StaticValueKnowledge* noKnowledge = nullptr;
  return resolveContiguousAddressImpl(value, noKnowledge);
}
|
|
|
|
/// Resolves `value` to its base buffer and byte offset, consulting `knowledge`
/// for known aliases and index values. Returns failure when the address cannot
/// be resolved.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
                                                                    const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveContiguousAddressImpl(value, knowledgePtr);
}
|
|
|
|
/// Public entry point for the loop-carried alias walk: forwards `value` and the
/// address of `knowledge` to `resolveLoopCarriedAliasImpl` and returns its result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
  const StaticValueKnowledge* knowledgePtr = &knowledge;
  return resolveLoopCarriedAliasImpl(value, knowledgePtr);
}
|
|
|
|
} // namespace onnx_mlir
|