better MaterializeMergeSchedule.cpp with %lane indexed batch computes

support for tensors of index values
This commit is contained in:
NiccoloN
2026-05-22 21:52:28 +02:00
parent 495186503c
commit c77ffa9c56
20 changed files with 398 additions and 300 deletions
+1 -1
View File
@@ -264,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
+25
View File
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements;
}
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
+11
View File
@@ -1,8 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,