@@ -5,6 +5,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
@@ -72,4 +73,117 @@ mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Op
|
||||
rewriter, operation->getLoc(), resultShapedType.getShape(), resultShapedType.getElementType());
|
||||
}
|
||||
|
||||
LogicalResult validateFragmentAssemblyMetadata(spatial::SpatReconciliatorOp reconciliator,
|
||||
int64_t resultRank,
|
||||
size_t operandCount,
|
||||
ArrayRef<int64_t> operandIndices,
|
||||
ArrayRef<int64_t> sourceOffsets,
|
||||
ArrayRef<int64_t> flatOffsets,
|
||||
ArrayRef<int64_t> flatSizes,
|
||||
ArrayRef<int64_t> flatStrides) {
|
||||
if (operandIndices.size() != sourceOffsets.size())
|
||||
return reconciliator.emitOpError("fragment assembly operand index and source offset counts must match");
|
||||
if (flatOffsets.size() != flatSizes.size())
|
||||
return reconciliator.emitOpError("fragment assembly offset and size arrays must have matching lengths");
|
||||
if (flatStrides.size() != flatOffsets.size())
|
||||
return reconciliator.emitOpError("fragment assembly stride and offset arrays must have matching lengths");
|
||||
if (flatOffsets.size() != operandIndices.size() * static_cast<size_t>(resultRank))
|
||||
return reconciliator.emitOpError("fragment assembly metadata must provide one rank-sized offset/size/stride tuple per fragment");
|
||||
|
||||
for (auto [fragmentIndex, operandIndex] : llvm::enumerate(operandIndices)) {
|
||||
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(operandCount))
|
||||
return reconciliator.emitOpError("fragment assembly operand index is out of range");
|
||||
if (sourceOffsets[fragmentIndex] < 0)
|
||||
return reconciliator.emitOpError("fragment assembly source offsets must be nonnegative");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static SmallVector<int64_t, 4> expandFlatElementIndex(int64_t flatIndex, ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t, 4> indices(shape.size(), 0);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 1; dim >= 0; --dim) {
|
||||
indices[dim] = flatIndex % shape[dim];
|
||||
flatIndex /= shape[dim];
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t, 4>>
|
||||
getStaticSliceOffsetsForElementOffset(Operation* anchor,
|
||||
ShapedType sourceType,
|
||||
ArrayRef<int64_t> fragmentShape,
|
||||
int64_t sourceElementOffset,
|
||||
StringRef fieldName) {
|
||||
if (!sourceType.hasStaticShape())
|
||||
return (anchor->emitOpError() << fieldName << " requires a static source shape"), failure();
|
||||
if (sourceElementOffset < 0)
|
||||
return (anchor->emitOpError() << fieldName << " requires a nonnegative source element offset"), failure();
|
||||
if (sourceType.getRank() != static_cast<int64_t>(fragmentShape.size()))
|
||||
return (anchor->emitOpError() << fieldName << " requires fragment rank to match source rank"), failure();
|
||||
|
||||
int64_t sourceElementCount = sourceType.getNumElements();
|
||||
int64_t fragmentElementCount = 1;
|
||||
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||
if (fragmentShape[dim] < 0)
|
||||
return (anchor->emitOpError() << fieldName << " requires nonnegative fragment sizes"), failure();
|
||||
fragmentElementCount *= fragmentShape[dim];
|
||||
}
|
||||
if (sourceElementOffset + fragmentElementCount > sourceElementCount)
|
||||
return (anchor->emitOpError() << fieldName << " exceeds the source tensor bounds"), failure();
|
||||
|
||||
SmallVector<int64_t, 4> sliceOffsets = expandFlatElementIndex(sourceElementOffset, sourceType.getShape());
|
||||
for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||
if (sliceOffsets[dim] + fragmentShape[dim] > sourceType.getDimSize(dim))
|
||||
return (anchor->emitOpError() << fieldName << " does not describe a valid unit-stride slice"), failure();
|
||||
}
|
||||
return sliceOffsets;
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
forEachContiguousDestinationChunk(ArrayRef<int64_t> destShape,
|
||||
ArrayRef<int64_t> baseOffsets,
|
||||
ArrayRef<int64_t> sizes,
|
||||
llvm::function_ref<LogicalResult(ArrayRef<int64_t>, int64_t, int64_t)> callback) {
|
||||
int64_t rank = static_cast<int64_t>(sizes.size());
|
||||
int64_t suffixStart = rank - 1;
|
||||
while (suffixStart > 0 && sizes[suffixStart] == destShape[suffixStart])
|
||||
--suffixStart;
|
||||
if (sizes[suffixStart] == destShape[suffixStart] && suffixStart == 0)
|
||||
suffixStart = 0;
|
||||
else
|
||||
++suffixStart;
|
||||
|
||||
int64_t chunkElements = 1;
|
||||
for (int64_t dim = suffixStart; dim < rank; ++dim)
|
||||
chunkElements *= sizes[dim];
|
||||
|
||||
SmallVector<int64_t, 4> prefixExtents(sizes.begin(), sizes.begin() + suffixStart);
|
||||
SmallVector<int64_t, 4> current(prefixExtents.size(), 0);
|
||||
int64_t sourceChunkOrdinal = 0;
|
||||
|
||||
auto visit = [&](auto&& visit, int64_t dim) -> LogicalResult {
|
||||
if (dim == static_cast<int64_t>(prefixExtents.size())) {
|
||||
SmallVector<int64_t, 4> chunkOffsets(baseOffsets.begin(), baseOffsets.end());
|
||||
for (int64_t prefixDim = 0; prefixDim < static_cast<int64_t>(current.size()); ++prefixDim)
|
||||
chunkOffsets[prefixDim] += current[prefixDim];
|
||||
if (failed(callback(chunkOffsets, sourceChunkOrdinal * chunkElements, chunkElements)))
|
||||
return failure();
|
||||
++sourceChunkOrdinal;
|
||||
return success();
|
||||
}
|
||||
|
||||
for (int64_t index = 0; index < prefixExtents[dim]; ++index) {
|
||||
current[dim] = index;
|
||||
if (failed(visit(visit, dim + 1)))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
if (prefixExtents.empty())
|
||||
return callback(baseOffsets, 0, chunkElements);
|
||||
return visit(visit, 0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user