#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

// Returns true when every static stride of the slice equals 1. A dynamic
// stride is encoded with a sentinel that differs from 1, so it yields false.
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
  return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}

// Returns true when every index operand of the extract is produced by a
// constant operation (block arguments have no defining op and yield false).
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
  // NOTE(review): the template argument of isa_and_nonnull was missing in the
  // source; arith::ConstantOp is assumed here - confirm.
  return llvm::all_of(
      extractOp.getIndices(), [](Value index) { return isa_and_nonnull<arith::ConstantOp>(index.getDefiningOp()); });
}

// Returns true when all results of `op` are shaped types with a static shape.
// An operation without results trivially satisfies the predicate.
static bool isStaticTensorResult(Operation* op) {
  return llvm::all_of(op->getResultTypes(), [](Type type) {
    auto shapedType = dyn_cast<ShapedType>(type);
    return shapedType && shapedType.hasStaticShape();
  });
}

// Computes row-major (innermost dimension contiguous) strides for `shape`.
// The result has one entry per dimension; an empty shape gives an empty vector.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

// Builds a new dense attribute whose dimension `d` is dimension `perms[d]` of
// `denseAttr`. Fails when the attribute is not a ranked tensor or when `perms`
// is not a permutation of [0, rank).
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!tensorType)
    return failure();

  const int64_t rank = tensorType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate the permutation while deriving the transposed shape.
  llvm::SmallBitVector seen(rank);
  SmallVector<int64_t> transposedShape;
  transposedShape.reserve(rank);
  for (int64_t perm : perms) {
    if (perm < 0 || perm >= rank || seen.test(perm))
      return failure();
    seen.set(perm);
    transposedShape.push_back(tensorType.getShape()[perm]);
  }

  auto transposedType =
      RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
  // A splat has the same value everywhere, so only the type changes.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> transposedValues(originalValues.size());
  SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
  SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
  SmallVector<int64_t> originalIndices(rank);

  for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
    // Decompose the source linear index into per-dimension indices.
    int64_t remaining = static_cast<int64_t>(linearIndex);
    for (int64_t dim = 0; dim < rank; ++dim) {
      originalIndices[dim] = remaining / originalStrides[dim];
      remaining %= originalStrides[dim];
    }

    // Recompose the linear index in the transposed layout.
    int64_t transposedLinearIndex = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];

    transposedValues[transposedLinearIndex] = value;
  }

  return DenseElementsAttr::get(transposedType, transposedValues);
}

// Re-types `denseAttr` as `resultType`, keeping the elements in the same
// linear order. Fails when either type is missing or the element counts differ.
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
    return failure();

  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
  return DenseElementsAttr::get(resultType, values);
}

// Evaluates a fully static, unit-stride tensor.extract_slice on `denseAttr`.
//
// The static offsets/sizes/strides always have one entry per *source*
// dimension, while the result type may have fewer dimensions (rank-reducing
// slice). The slice is therefore walked in the source rank using `sizes`;
// dropped dimensions have size 1, so this produces the elements in the
// row-major order of the result type.
//
// Fails when a type is not a static ranked tensor, when any offset/size/stride
// is dynamic or a stride differs from 1, when the slice does not fit inside the
// source, or when the slice element count disagrees with the result type.
static FailureOr<DenseElementsAttr> extractSliceDenseElements(
    DenseElementsAttr denseAttr, tensor::ExtractSliceOp extractSliceOp) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
  ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
  ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
  if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
    return failure();

  const int64_t sourceRank = sourceType.getRank();
  if (static_cast<int64_t>(offsets.size()) != sourceRank || static_cast<int64_t>(sizes.size()) != sourceRank
      || static_cast<int64_t>(strides.size()) != sourceRank)
    return failure();

  // Reject slices that would read outside the source buffer and make sure the
  // slice holds exactly as many elements as the result type expects.
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  int64_t sliceNumElements = 1;
  for (int64_t dim = 0; dim < sourceRank; ++dim) {
    if (offsets[dim] < 0 || sizes[dim] < 0 || offsets[dim] + sizes[dim] > sourceShape[dim])
      return failure();
    sliceNumElements *= sizes[dim];
  }
  if (sliceNumElements != resultType.getNumElements())
    return failure();

  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> sliceStrides = computeRowMajorStrides(sizes);
  SmallVector<Attribute> resultValues;
  resultValues.reserve(sliceNumElements);

  for (int64_t linearIndex = 0; linearIndex < sliceNumElements; ++linearIndex) {
    int64_t remaining = linearIndex;
    int64_t sourceLinearIndex = 0;
    for (int64_t dim = 0; dim < sourceRank; ++dim) {
      const int64_t sliceIndex = remaining / sliceStrides[dim];
      remaining %= sliceStrides[dim];
      sourceLinearIndex += (offsets[dim] + sliceIndex) * sourceStrides[dim];
    }
    resultValues.push_back(sourceValues[sourceLinearIndex]);
  }

  return DenseElementsAttr::get(resultType, resultValues);
}

// Returns the dense elements attribute held directly by the constant op that
// defines `value`, or a null attribute when `value` is not defined by a
// supported constant op or its payload is not a DenseElementsAttr.
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
    return dyn_cast<DenseElementsAttr>(constantOp.getValue());
  // The ONNX constant's value attribute may be absent, hence dyn_cast_or_null.
  if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
    return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
  return nullptr;
}

// Walks the producer chain of `value` and rebuilds the dense attribute it
// evaluates to. Returns a null attribute when the chain contains an
// unsupported op, a fold step fails, or an op is encountered twice.
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(
    Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
  auto* definingOp = value.getDefiningOp();
  if (!definingOp || !visited.insert(definingOp).second)
    return nullptr;

  // Rebuild dense attributes through view-only host-foldable chains so later
  // lowering stages can still recognize grouped/sliced constants.
  if (auto denseAttr = getDirectDenseConstantAttr(value))
    return denseAttr;

  if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
    if (!inputAttr)
      return nullptr;

    SmallVector<int64_t> perm;
    if (ArrayAttr permAttr = transposeOp.getPermAttr()) {
      perm.reserve(permAttr.size());
      for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
        perm.push_back(attr.getInt());
    }
    else {
      // `perm` is optional on ONNX Transpose; when absent the operator
      // reverses the dimension order.
      auto inputType = dyn_cast<RankedTensorType>(inputAttr.getType());
      if (!inputType)
        return nullptr;
      perm.reserve(inputType.getRank());
      for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
        perm.push_back(dim);
    }

    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    if (failed(transposedAttr))
      return nullptr;
    return *transposedAttr;
  }

  if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
    if (failed(reshapedAttr))
      return nullptr;
    return *reshapedAttr;
  }

  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
    if (failed(reshapedAttr))
      return nullptr;
    return *reshapedAttr;
  }

  if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
    if (!inputAttr)
      return nullptr;
    auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
    if (failed(slicedAttr))
      return nullptr;
    return *slicedAttr;
  }

  return nullptr;
}

static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited);

// Returns true when `value` is produced by a host-foldable operation. Block
// arguments (no defining op) are not foldable.
static bool isHostFoldableValueImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
  auto* definingOp = value.getDefiningOp();
  return definingOp && isHostFoldableOpImpl(definingOp, visited);
}

// Classifies `op` assuming it has already been registered in `visited`.
static bool isHostFoldableOpBody(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
  // NOTE(review): the template arguments of this isa<> were missing in the
  // source; the two constant ops handled by getDirectDenseConstantAttr are
  // assumed - confirm.
  if (isa<arith::ConstantOp, ONNXConstantOp>(op))
    return true;

  // tensor.extract yields a scalar, so it is checked before the static
  // tensor-result requirement below.
  if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
    return hasConstantIndices(extractOp) && isHostFoldableValueImpl(extractOp.getTensor(), visited);

  if (!isStaticTensorResult(op))
    return false;

  if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
    return isHostFoldableValueImpl(transposeOp.getData(), visited);

  if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
    return isHostFoldableValueImpl(collapseShapeOp.getSrc(), visited);

  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
    return isHostFoldableValueImpl(expandShapeOp.getSrc(), visited);

  if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
    return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValueImpl(extractSliceOp.getSource(), visited);

  if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
    return isHostFoldableValueImpl(splatOp.getInput(), visited);

  // NOTE(review): the op classes of the next two casts were missing in the
  // source; the Spatial dialect extract-rows and concat ops are assumed -
  // confirm the exact class names against SpatialOps.hpp.
  if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
    return isHostFoldableValueImpl(extractRowsOp.getInput(), visited);

  if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
    return llvm::all_of(
        concatOp.getInputs(), [&visited](Value input) { return isHostFoldableValueImpl(input, visited); });

  return false;
}

// Returns true when `op` and, recursively, the producers it depends on can be
// folded on the host.
//
// `visited` holds the operations currently being examined on the recursion
// path. An op is removed again once it has been classified, so an op reached
// twice through different operands is still accepted, while a cyclic
// dependency is rejected instead of recursing forever.
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
  if (!op || !visited.insert(op).second)
    return false;

  const bool foldable = isHostFoldableOpBody(op, visited);
  visited.erase(op);
  return foldable;
}

} // namespace

// Returns true when `value` is defined by a host-foldable operation.
bool isHostFoldableValue(Value value) {
  auto* definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;

  // NOTE(review): the inline element count of SmallPtrSet was missing in the
  // source; 8 is assumed (it only affects stack usage, not behavior).
  llvm::SmallPtrSet<Operation*, 8> visited;
  return isHostFoldableOpImpl(definingOp, visited);
}

// Returns true when `op` is host-foldable; a null `op` yields false.
bool isHostFoldableOp(Operation* op) {
  llvm::SmallPtrSet<Operation*, 8> visited;
  return isHostFoldableOpImpl(op, visited);
}

// Returns the dense attribute that `value` evaluates to when it is a constant
// seen through transpose / collapse_shape / expand_shape / extract_slice ops,
// or a null attribute otherwise.
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
  llvm::SmallPtrSet<Operation*, 8> visited;
  return getHostFoldableDenseElementsAttrImpl(value, visited);
}

} // namespace onnx_mlir