Bose
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-26 17:45:27 +02:00
parent 984f362623
commit 78e97f9fd8
23 changed files with 513 additions and 17489 deletions
@@ -163,6 +163,38 @@ Value extractAxisSlice(
.getResult();
}
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
Location loc,
Value source,
RankedTensorType resultType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto sourceType = cast<RankedTensorType>(source.getType());
size_t rank = static_cast<size_t>(sourceType.getRank());
bool isIdentitySlice =
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
&& strides.size() == rank;
if (isIdentitySlice) {
ArrayRef<int64_t> sourceShape = sourceType.getShape();
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
isIdentitySlice = false;
break;
}
}
}
if (isIdentitySlice)
return source;
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
@@ -105,6 +105,14 @@ llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPer
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::RankedTensorType resultType,
llvm::ArrayRef<mlir::OpFoldResult> offsets,
llvm::ArrayRef<mlir::OpFoldResult> sizes,
llvm::ArrayRef<mlir::OpFoldResult> strides);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,