300 lines
12 KiB
C++
300 lines
12 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <utility>
|
|
|
|
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Returns true only when every entry of the op's static stride list equals 1.
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
  for (int64_t stride : extractSliceOp.getStaticStrides())
    if (stride != 1)
      return false;
  return true;
}
|
|
|
|
/// Returns true when `matchConstantIndexValue` yields a value for every index
/// operand of `extractOp` (trivially true when there are no indices).
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
  for (Value index : extractOp.getIndices())
    if (!matchConstantIndexValue(index).has_value())
      return false;
  return true;
}
|
|
|
|
/// Returns true when every result type of `op` is a ShapedType with a fully
/// static shape. An op without results satisfies the check vacuously.
static bool isStaticTensorResult(Operation* op) {
  for (Type resultType : op->getResultTypes()) {
    auto shaped = dyn_cast<ShapedType>(resultType);
    if (!shaped || !shaped.hasStaticShape())
      return false;
  }
  return true;
}
|
|
|
|
/// Builds a dense attribute holding `denseAttr` with its axes permuted.
///
/// `perms[i]` names the source axis that becomes axis `i` of the result.
/// Fails when the attribute is not a ranked tensor, when `perms` does not have
/// one entry per axis, or when `perms` is not a permutation of [0, rank).
/// Splat attributes are re-typed without materializing their elements.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType)
    return failure();

  const int64_t rank = sourceType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate the permutation while collecting the permuted shape.
  llvm::SmallBitVector usedAxes(rank);
  SmallVector<int64_t> resultShape;
  resultShape.reserve(rank);
  for (int64_t axis : perms) {
    const bool inRange = axis >= 0 && axis < rank;
    if (!inRange || usedAxes.test(axis))
      return failure();
    usedAxes.set(axis);
    resultShape.push_back(sourceType.getShape()[axis]);
  }

  auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(sourceValues.size());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);
  SmallVector<int64_t> sourceCoords(rank);

  const int64_t elementCount = static_cast<int64_t>(sourceValues.size());
  for (int64_t sourceOffset = 0; sourceOffset < elementCount; ++sourceOffset) {
    // Split the flat source offset into one coordinate per source axis.
    int64_t rest = sourceOffset;
    for (int64_t axis = 0; axis < rank; ++axis) {
      sourceCoords[axis] = rest / sourceStrides[axis];
      rest %= sourceStrides[axis];
    }

    // Result axis `axis` reads source axis `perms[axis]`.
    int64_t resultOffset = 0;
    for (int64_t axis = 0; axis < rank; ++axis)
      resultOffset += sourceCoords[perms[axis]] * resultStrides[axis];

    resultValues[resultOffset] = sourceValues[sourceOffset];
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
|
|
|
|
/// Re-types `denseAttr` to `resultType`, keeping the elements in their
/// existing linear order.
///
/// Fails when the attribute is not a ranked tensor, when `resultType` is null
/// or not fully static, or when the element counts differ. The static-shape
/// guard runs before the element counts are compared, because the element
/// count of a dynamically shaped type cannot be queried.
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !resultType)
    return failure();
  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  if (sourceType.getNumElements() != resultType.getNumElements())
    return failure();

  // A splat stays a splat; only the type changes.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
  return DenseElementsAttr::get(resultType, values);
}
|
|
|
|
/// Evaluates `extractSliceOp` on the constant `denseAttr` and returns the
/// sliced elements as a dense attribute of the op's result type.
///
/// Only fully static slices with unit strides are supported; anything else
/// yields failure. The slice is walked in the coordinate space of the SOURCE
/// tensor using the op's static sizes, so rank-reducing slices (result rank
/// smaller than source rank) map to the correct source elements. Slices that
/// fall outside the source shape, or whose element count disagrees with the
/// result type, are rejected instead of being read out of bounds.
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
                                                              tensor::ExtractSliceOp extractSliceOp) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
  ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
  ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
  if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
    return failure();

  // Every element of a splat is identical, so any slice of it is the same splat.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  const int64_t sourceRank = sourceType.getRank();
  const auto expectedEntries = static_cast<size_t>(sourceRank);
  if (offsets.size() != expectedEntries || sizes.size() != expectedEntries)
    return failure();

  // Check that the slice lies inside the source and count its elements.
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  int64_t sliceNumElements = 1;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    if (offsets[dim] < 0 || sizes[dim] < 0 || offsets[dim] > sourceShape[dim]
        || sizes[dim] > sourceShape[dim] - offsets[dim])
      return failure();
    sliceNumElements *= sizes[dim];
  }

  // Dropping unit dimensions does not change the row-major element order, so
  // the un-reduced slice and the result type must hold the same element count.
  const int64_t resultNumElements = resultType.getNumElements();
  if (sliceNumElements != resultNumElements)
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> sliceStrides = computeRowMajorStrides(sizes);
  if (sourceStrides.size() != expectedEntries || sliceStrides.size() != expectedEntries)
    return failure();

  SmallVector<Attribute> resultValues;
  resultValues.reserve(resultNumElements);

  for (int64_t linearIndex = 0; linearIndex < resultNumElements; ++linearIndex) {
    int64_t remaining = linearIndex;
    int64_t sourceLinearIndex = 0;
    for (int64_t dim = 0; dim < sourceRank; ++dim) {
      // Coordinate inside the slice along `dim`, shifted by the slice offset.
      const int64_t sliceIndex = remaining / sliceStrides[dim];
      remaining %= sliceStrides[dim];
      sourceLinearIndex += (offsets[dim] + sliceIndex) * sourceStrides[dim];
    }
    resultValues.push_back(sourceValues[sourceLinearIndex]);
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
|
|
|
|
/// Returns the dense elements attribute carried by the op defining `value`
/// when that op is an `arith.constant` or an ONNX constant.
///
/// Yields a null attribute when `value` has no defining op, when the defining
/// op is neither kind of constant, or when the constant's payload is not a
/// DenseElementsAttr (including an absent ONNX `value` attribute).
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return nullptr;

  if (auto arithConstant = dyn_cast<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  if (auto onnxConstant = dyn_cast<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
|
|
|
|
/// Recursively folds the producer chain of `value` into a dense elements
/// attribute, or returns a null attribute when that is not possible.
///
/// Supported producers are direct constants (`arith.constant`, ONNX constant)
/// and, on top of them, ONNX/linalg transposes, `tensor.collapse_shape`,
/// `tensor.expand_shape` and `tensor.extract_slice`. `visited` records the ops
/// already walked; reaching an op a second time aborts with a null attribute.
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
  auto* definingOp = value.getDefiningOp();
  if (!definingOp || !visited.insert(definingOp).second)
    return nullptr;

  // Rebuild dense attributes through view-only host-foldable chains so later
  // lowering stages can still recognize grouped/sliced constants.
  if (auto denseAttr = getDirectDenseConstantAttr(value))
    return denseAttr;

  if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
    if (!inputAttr)
      return nullptr;

    SmallVector<int64_t> perm;
    if (auto permAttr = transposeOp.getPermAttr()) {
      perm.reserve(permAttr.size());
      for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
        perm.push_back(attr.getInt());
    }
    else {
      // `perm` is optional on ONNX Transpose; when omitted the operator
      // reverses the axis order. Dereferencing the absent attribute would crash.
      const int64_t rank = inputAttr.getType().getRank();
      perm.reserve(rank);
      for (int64_t axis = rank - 1; axis >= 0; --axis)
        perm.push_back(axis);
    }
    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    return succeeded(transposedAttr) ? *transposedAttr : nullptr;
  }

  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
    if (!inputAttr)
      return nullptr;

    SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    return succeeded(transposedAttr) ? *transposedAttr : nullptr;
  }

  if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
    auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
    if (!inputAttr)
      return nullptr;
    auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
    return succeeded(slicedAttr) ? *slicedAttr : nullptr;
  }

  // Any other producer is not foldable on the host.
  return nullptr;
}
|
|
|
|
static std::optional<CompileTimeSource>
|
|
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
|
if (!op)
|
|
return std::nullopt;
|
|
|
|
if (!visited.insert(op).second)
|
|
return {
|
|
{op, chainLength}
|
|
};
|
|
|
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
|
return {
|
|
{op, chainLength}
|
|
};
|
|
|
|
chainLength += 1;
|
|
|
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
|
return hasConstantIndices(extractOp)
|
|
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
|
|
: std::nullopt;
|
|
|
|
if (!isStaticTensorResult(op))
|
|
return std::nullopt;
|
|
|
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
|
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
|
|
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
|
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
|
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
|
return hasStaticUnitStrides(extractSliceOp)
|
|
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
|
: std::nullopt;
|
|
|
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
|
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
|
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
|
|
|
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
|
std::optional<CompileTimeSource> res = {};
|
|
for (auto operandValue : concatOp.getOperands()) {
|
|
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
|
|
if (!partialRes)
|
|
return std::nullopt;
|
|
|
|
if (!res) {
|
|
res = partialRes;
|
|
continue;
|
|
}
|
|
if (res->chainLength < partialRes->chainLength)
|
|
res = partialRes;
|
|
}
|
|
return res;
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Public entry point: resolves the compile-time source of `op` using a fresh
/// visited set. Returns `std::nullopt` when no source can be determined.
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  std::optional<CompileTimeSource> source = getCompileTimeSourceImpl(op, seenOps);
  return source;
}
|
|
|
|
/// Returns true when `value` is produced by an op for which a compile-time
/// source can be resolved. Values without a defining op yield false.
bool isCompileTimeComputable(Value value) {
  Operation* producer = value.getDefiningOp();
  if (producer == nullptr)
    return false;

  llvm::SmallPtrSet<Operation*, 8> seenOps;
  const std::optional<CompileTimeSource> source = getCompileTimeSourceImpl(producer, seenOps);
  return source.has_value();
}
|
|
|
|
/// Returns true when a compile-time source can be resolved for `op`.
/// A null `op` yields false.
bool isCompileTimeOp(Operation* op) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  const std::optional<CompileTimeSource> source = getCompileTimeSourceImpl(op, seenOps);
  return source.has_value();
}
|
|
|
|
/// Public entry point: folds the producer chain of `value` into a dense
/// elements attribute using a fresh visited set. Returns a null attribute
/// when the chain cannot be folded.
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  DenseElementsAttr folded = getHostConstantDenseElementsAttrImpl(value, seenOps);
  return folded;
}
|
|
|
|
} // namespace onnx_mlir
|