d09e76c8f9
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering add support for grouped-convolution lowering quieter verifier with capped error messages
257 lines
10 KiB
C++
257 lines
10 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/SmallPtrSet.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
|
|
namespace {
|
|
|
|
/// Returns true when every stride of `extractSliceOp` is the static value 1.
/// Dynamic strides show up in the static-stride list as a sentinel that is not
/// 1, so they are rejected as well.
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
  for (int64_t stride : extractSliceOp.getStaticStrides())
    if (stride != 1)
      return false;
  return true;
}
|
|
|
|
/// Returns true when every index operand of `extractOp` is produced by an
/// index-typed `arith.constant`. An extract with no index operands trivially
/// qualifies.
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
  return llvm::none_of(extractOp.getIndices(), [](Value index) {
    return !isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp());
  });
}
|
|
|
|
/// Returns true when every result of `op` is a ShapedType (any shaped type, not
/// only tensors) with a fully static shape. An op without results vacuously
/// passes.
static bool isStaticTensorResult(Operation* op) {
  return llvm::none_of(op->getResultTypes(), [](Type resultType) {
    auto shaped = dyn_cast<ShapedType>(resultType);
    return !shaped || !shaped.hasStaticShape();
  });
}
|
|
|
|
/// Returns the row-major (C-order) element strides for `shape`: the innermost
/// dimension has stride 1, and each outer stride is the product of all inner
/// extents. An empty shape yields an empty stride vector.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (size_t inner = shape.size(); inner > 1; --inner)
    strides[inner - 2] = strides[inner - 1] * shape[inner - 1];
  return strides;
}
|
|
|
|
/// Materializes an ONNX-style transpose of a dense constant on the host.
///
/// Result dimension `i` takes its extent from source dimension `perms[i]`.
/// Fails when the attribute type is not a RankedTensorType, or when `perms` is
/// not a permutation of `[0, rank)`. A splat input produces a splat result; the
/// source encoding is carried over to the result type.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType)
    return failure();

  const int64_t rank = sourceType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  llvm::SmallBitVector used(rank);
  SmallVector<int64_t> resultShape(rank);
  for (int64_t resultDim = 0; resultDim < rank; ++resultDim) {
    const int64_t sourceDim = perms[resultDim];
    if (sourceDim < 0 || sourceDim >= rank || used.test(sourceDim))
      return failure();
    used.set(sourceDim);
    resultShape[resultDim] = sourceShape[sourceDim];
  }

  auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);

  // scatterStrides[d] is how far one step along source dimension `d` moves in
  // the result's row-major linear index (source dim perms[i] becomes result
  // dim i).
  SmallVector<int64_t> scatterStrides(rank);
  for (int64_t resultDim = 0; resultDim < rank; ++resultDim)
    scatterStrides[perms[resultDim]] = resultStrides[resultDim];

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> resultValues(sourceValues.size());
  for (size_t sourceIndex = 0; sourceIndex < sourceValues.size(); ++sourceIndex) {
    int64_t remainder = static_cast<int64_t>(sourceIndex);
    int64_t resultIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim) {
      resultIndex += (remainder / sourceStrides[dim]) * scatterStrides[dim];
      remainder %= sourceStrides[dim];
    }
    resultValues[resultIndex] = sourceValues[sourceIndex];
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
|
|
|
|
/// Reinterprets the row-major contents of `denseAttr` under `resultType`
/// (element order unchanged), as needed for tensor.collapse_shape and
/// tensor.expand_shape.
///
/// Fails when the attribute type is not a RankedTensorType, or when
/// `resultType` is null, not fully static, has a different element type, or
/// holds a different number of elements. A splat input produces a splat result.
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !resultType)
    return failure();

  // ShapedType::getNumElements() asserts on dynamic shapes, and building a
  // DenseElementsAttr with mismatched element types asserts as well, so both
  // are rejected before touching the data.
  if (!resultType.hasStaticShape() || resultType.getElementType() != sourceType.getElementType())
    return failure();
  if (sourceType.getNumElements() != resultType.getNumElements())
    return failure();

  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
  return DenseElementsAttr::get(resultType, values);
}
|
|
|
|
/// Materializes the result of a static, unit-stride tensor.extract_slice applied
/// to a dense constant.
///
/// The slice is walked in the coordinate space of the *source* (offsets, sizes
/// and strides all have source rank). This makes rank-reducing slices, whose
/// result type drops unit dimensions, come out right. Dropping unit dimensions
/// does not change row-major order, so the values gathered over `sizes` are
/// already in the order of `resultType`.
///
/// Fails when either type is not a statically shaped RankedTensorType, when the
/// element types differ, when any offset/size is dynamic or any stride is not
/// 1, when the slice is out of bounds, or when its element count disagrees with
/// `resultType`. A splat input produces a splat result.
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
                                                              tensor::ExtractSliceOp extractSliceOp) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  if (sourceType.getElementType() != resultType.getElementType())
    return failure();

  ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
  ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
  const int64_t sourceRank = sourceType.getRank();
  if (static_cast<int64_t>(offsets.size()) != sourceRank || static_cast<int64_t>(sizes.size()) != sourceRank
      || static_cast<int64_t>(extractSliceOp.getStaticStrides().size()) != sourceRank)
    return failure();
  if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
      || !hasStaticUnitStrides(extractSliceOp))
    return failure();

  // Guard the gather below against out-of-range reads, and make sure the slice
  // has exactly as many elements as the declared result type.
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  int64_t sliceNumElements = 1;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    if (offsets[dim] < 0 || sizes[dim] < 0 || offsets[dim] + sizes[dim] > sourceShape[dim])
      return failure();
    sliceNumElements *= sizes[dim];
  }
  if (sliceNumElements != resultType.getNumElements())
    return failure();

  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  // Strides of the un-reduced slice shape (`sizes`), used to decompose the
  // slice-local linear index into per-source-dimension coordinates.
  SmallVector<int64_t> sliceStrides = computeRowMajorStrides(sizes);
  SmallVector<Attribute> resultValues;
  resultValues.reserve(sliceNumElements);

  // When any size is 0 the loop body never runs, so a zero stride is never
  // used as a divisor.
  for (int64_t linearIndex = 0; linearIndex < sliceNumElements; ++linearIndex) {
    int64_t remaining = linearIndex;
    int64_t sourceLinearIndex = 0;
    for (int64_t dim = 0; dim < sourceRank; ++dim) {
      const int64_t sliceIndex = remaining / sliceStrides[dim];
      remaining %= sliceStrides[dim];
      sourceLinearIndex += (offsets[dim] + sliceIndex) * sourceStrides[dim];
    }
    resultValues.push_back(sourceValues[sourceLinearIndex]);
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
|
|
|
|
/// Returns the dense payload of `value` when it is produced directly by an
/// `arith.constant` or an `ONNXConstantOp` carrying a DenseElementsAttr.
/// Otherwise returns a null attribute. No intermediate ops are looked through.
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return nullptr;

  if (auto arithConstant = dyn_cast<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  if (auto onnxConstant = dyn_cast<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
|
|
|
|
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
|
auto* definingOp = value.getDefiningOp();
|
|
if (!definingOp || !visited.insert(definingOp).second)
|
|
return nullptr;
|
|
|
|
// Rebuild dense attributes through view-only host-foldable chains so later
|
|
// lowering stages can still recognize grouped/sliced constants.
|
|
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
|
return denseAttr;
|
|
|
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
|
if (!inputAttr)
|
|
return nullptr;
|
|
|
|
SmallVector<int64_t> perm;
|
|
perm.reserve(transposeOp.getPermAttr().size());
|
|
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
|
perm.push_back(attr.getInt());
|
|
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
|
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
|
}
|
|
|
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
|
if (!inputAttr)
|
|
return nullptr;
|
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
|
}
|
|
|
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
|
if (!inputAttr)
|
|
return nullptr;
|
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
|
}
|
|
|
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
|
if (!inputAttr)
|
|
return nullptr;
|
|
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
|
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
|
}
|
|
|
|
return nullptr;
|
|
}
|
|
|
|
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
|
if (!op || !visited.insert(op).second)
|
|
return false;
|
|
|
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
|
return true;
|
|
|
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
|
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
|
|
|
if (!isStaticTensorResult(op))
|
|
return false;
|
|
|
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
|
return isHostFoldableValue(transposeOp.getData());
|
|
|
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
|
return isHostFoldableValue(collapseShapeOp.getSrc());
|
|
|
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
|
return isHostFoldableValue(expandShapeOp.getSrc());
|
|
|
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
|
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
|
|
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
|
return isHostFoldableValue(splatOp.getInput());
|
|
|
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
|
return isHostFoldableValue(extractRowsOp.getInput());
|
|
|
|
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
|
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
|
|
|
return false;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
/// Returns true when `value` can be materialized on the host. It must be the
/// result of an op (block arguments are rejected) that passes the host-fold
/// check.
bool isHostFoldableValue(Value value) {
  Operation* producer = value.getDefiningOp();
  llvm::SmallPtrSet<Operation*, 8> seen;
  return producer != nullptr && isHostFoldableOpImpl(producer, seen);
}
|
|
|
|
/// Returns true when `op` itself passes the host-fold check. A null `op` never
/// does.
bool isHostFoldableOp(Operation* op) {
  if (op == nullptr)
    return false;

  llvm::SmallPtrSet<Operation*, 8> seen;
  return isHostFoldableOpImpl(op, seen);
}
|
|
|
|
/// Returns the constant contents of `value`, folded on the host through the
/// supported view-like ops. Returns a null attribute when that is not possible.
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  DenseElementsAttr folded = getHostFoldableDenseElementsAttrImpl(value, seenOps);
  return folded;
}
|
|
|
|
} // namespace onnx_mlir
|