better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
This commit is contained in:
@@ -427,11 +427,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||
if (!weightArg || !inputArg)
|
||||
return failure();
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
||||
gemmLoc,
|
||||
currOutHSliceType,
|
||||
*weightArg,
|
||||
*inputArg));
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
|
||||
}
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
|
||||
@@ -121,7 +121,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
tensor::ParallelInsertSliceOp insertSlice,
|
||||
ShapedType destinationType,
|
||||
IRMapping& mapper) {
|
||||
int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8;
|
||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
||||
ArrayRef<int64_t> shape = destinationType.getShape();
|
||||
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
||||
|
||||
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
|
||||
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@ namespace onnx_mlir {
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -433,7 +433,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
@@ -455,7 +455,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
@@ -471,7 +471,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
|
||||
@@ -325,9 +325,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
Type elementType = tensorType.getElementType();
|
||||
if (!elementType.isIntOrFloat())
|
||||
if (!hasByteSizedElementType(elementType))
|
||||
return;
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||
|
||||
Reference in New Issue
Block a user