better MaterializeMergeSchedule.cpp with %lane indexed batch computes

support for tensors of index values
This commit is contained in:
NiccoloN
2026-05-22 21:52:28 +02:00
parent 495186503c
commit c77ffa9c56
20 changed files with 398 additions and 300 deletions
@@ -427,11 +427,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto inputArg = computeOp.getInputArgument(aHSliceId);
if (!weightArg || !inputArg)
return failure();
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
gemmLoc,
currOutHSliceType,
*weightArg,
*inputArg));
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
}
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
@@ -121,7 +121,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides(destinationType.getRank(), 1);
ArrayRef<int64_t> shape = destinationType.getShape();
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
@@ -20,8 +20,6 @@ namespace onnx_mlir {
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T>
@@ -433,7 +433,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
markOpToRemove(op);
auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
@@ -455,7 +455,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter,
@@ -471,7 +471,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
}
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(concatOp);
@@ -325,9 +325,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat())
if (!hasByteSizedElementType(elementType))
return;
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);