#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"

// NOTE(review): the three system includes below had lost their targets in the
// mangled source ("#include" with no header). Restored to the standard
// headers this file demonstrably uses (std::ceil, std::optional,
// std::pair/std::swap) -- confirm against the pre-mangled file.
#include <cmath>
#include <optional>
#include <utility>

#include "ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Splits `tensorToSlice` along `axis` into consecutive slices of at most
/// `sliceSize` elements (the trailing slice is smaller when the dimension is
/// not a multiple of `sliceSize`). Returns the slices in ascending offset
/// order as tensor.extract_slice results.
SmallVector<Value> sliceTensor(const Value& tensorToSlice, size_t axis,
    int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
  ArrayRef<int64_t> shape = getTensorShape(tensorToSlice);
  assert("Invalid axis" && axis < shape.size());
  SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(shape.size());
  for (const auto size : shape)
    sizes.push_back(rewriter.getIndexAttr(size));
  sizes[axis] = rewriter.getIndexAttr(sliceSize);
  long length = shape[axis];
  auto [numSlices, lastSliceSize] =
      ceilIntegerDivideWithRemainder(length, sliceSize);
  SmallVector<Value> slices;
  slices.reserve(numSlices);
  for (int64_t i = 0; i < numSlices; i++) {
    offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
    // The last slice only covers the remainder of the dimension.
    if (i == numSlices - 1 && lastSliceSize != 0)
      sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
    Value slice = tensor::ExtractSliceOp::create(
        rewriter, loc, tensorToSlice, offsets, sizes, strides);
    slices.push_back(slice);
  }
  return slices;
}

/// Slices a (row or column) vector-shaped tensor into chunks of `sliceSize`
/// along its non-unit dimension.
SmallVector<Value> sliceVector(const Value& vectorToSlice, int64_t sliceSize,
    ConversionPatternRewriter& rewriter, Location loc) {
  ArrayRef<int64_t> shape = getTensorShape(vectorToSlice);
  assert("Not a vector" && isVectorShape(shape));
  // Slice along the dimension that actually carries the data.
  size_t axis = shape[0] != 1 ? 0 : 1;
  return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
}

/// Slices a vector into crossbar-sized chunks and groups them by the core
/// that owns each chunk (crossbarCountInCore consecutive slices per core).
DenseMap<size_t, SmallVector<Value>> sliceVectorPerCrossbarPerCore(
    const Value& vectorToSlice, ConversionPatternRewriter& rewriter,
    Location loc) {
  SmallVector<Value> slices =
      sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
  DenseMap<size_t, SmallVector<Value>> slicesPerCore;
  for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
    size_t coreId = sliceId / crossbarCountInCore;
    slicesPerCore[coreId].push_back(slices[sliceId]);
  }
  return slicesPerCore;
}

/// Tiles a 2-D matrix: first into horizontal slices of width `hSliceSize`
/// (axis 1), then each horizontal slice into vertical slices of height
/// `vSliceSize` (axis 0). Result is indexed as
/// tiles[hSliceId][coreId] -> vertical slices assigned to that core.
DenseMap<size_t, DenseMap<size_t, SmallVector<Value>>> tileMatrix(
    Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize,
    ConversionPatternRewriter& rewriter, Location& loc) {
  assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
  DenseMap<size_t, DenseMap<size_t, SmallVector<Value>>> tiles;
  SmallVector<Value> hSlices =
      sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
  size_t numHSlices = hSlices.size();
  for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
    Value hSlice = hSlices[hSliceId];
    SmallVector<Value> vSlices =
        sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
    for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
      size_t coreId = vSliceId / crossbarCountInCore;
      Value vSlice = vSlices[vSliceId];
      tiles[hSliceId][coreId].push_back(vSlice);
    }
  }
  return tiles;
}

/// Extracts the (sole) element of a scalar-shaped tensor and splats it into
/// a 1 x `length` row vector of the same element type.
tensor::SplatOp broadcastToVector(Value scalarToBroadcast, int64_t length,
    ConversionPatternRewriter& rewriter, Location loc) {
  auto oldType = cast<ShapedType>(scalarToBroadcast.getType());
  Type elementType = oldType.getElementType();
  int64_t shape[2] = {1, length};
  Type type = oldType.cloneWith(ArrayRef<int64_t>(shape), elementType);
  // Index [0, ..., 0]: the input is scalar-shaped, so this is its only
  // element.
  auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
  SmallVector<Value> index(oldType.getRank(), zero);
  auto elementValue =
      tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index)
          .getResult();
  return tensor::SplatOp::create(rewriter, loc, type, elementValue);
}

/// Reduces `tensors` to a single value via a balanced tree of SpatVAddOps
/// (pairwise additions per round), minimizing the add-chain depth compared
/// to a linear fold. Expects at least one tensor.
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
  if (tensors.size() == 1)
    return tensors[0];
  SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
  SmallVector<Value> tensors2;
  tensors2.reserve(tensors.size() / 2);
  auto* currTensors = &tensors1;
  auto* nextTensors = &tensors2;
  while (currTensors->size() > 1) {
    for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
      Value a = (*currTensors)[i];
      Value b = (*currTensors)[i + 1];
      // Insert after the later operand so both are dominated by the add.
      rewriter.setInsertionPointAfterValue(b);
      auto addedValue =
          spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
      nextTensors->push_back(addedValue);
    }
    // Odd element carries over to the next round unchanged.
    if (currTensors->size() % 2 == 1)
      nextTensors->push_back(currTensors->back());
    std::swap(currTensors, nextTensors);
    nextTensors->clear();
  }
  assert(currTensors->size() == 1 && "Expected a single input at this point.");
  return (*currTensors)[0];
}

/// Materializes the ONNX op corresponding to `mapOp` on `input`, preserving
/// the input's type. `MapOperations::None` is a caller error.
Value createMapOperation(
    PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
  switch (mapOp) {
  case MapOperations::None:
    // Fix: this used to be assert(false), which disappears in NDEBUG builds
    // and silently fell through into the softmax case below.
    llvm_unreachable("Invalid map operation during map operation creation.");
  case MapOperations::ONNXSoftmaxOp:
    return ONNXSoftmaxOp::create(
        rewriter, input.getLoc(), input.getType(), input);
  case MapOperations::ONNXReluOp:
    return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
  case MapOperations::ONNXLeakyReluOp:
    return ONNXLeakyReluOp::create(
        rewriter, input.getLoc(), input.getType(), input);
  case MapOperations::ONNXExpOp:
    return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
  }
  // Fix: control previously reached the end of this non-void function for an
  // out-of-range enum value (undefined behavior).
  llvm_unreachable("Unhandled map operation.");
}

/// Unpacks a 2-element integer ArrayAttr (e.g. strides) into `value1` and
/// `value2`; both default to 1 when the attribute is absent.
void unpackOptionalPairVector(std::optional<ArrayAttr> valuesArray,
    size_t& value1, size_t& value2) {
  if (auto unpackedStrides = valuesArray) {
    value1 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[0]).getInt();
    value2 = mlir::cast<IntegerAttr>(unpackedStrides->getValue()[1]).getInt();
  } else {
    value1 = 1;
    value2 = 1;
  }
}

/// Unpacks an ONNX `pads` attribute (4 elements: [x_begin, y_begin, x_end,
/// y_end]) into `pad_x`/`pad_y` (taken from the end-padding entries). Returns
/// an error message on malformed input, std::nullopt on success.
std::optional<std::string> unpackOptionalPadsVector(
    std::optional<ArrayAttr> valuesArray, size_t& pad_x, size_t& pad_y) {
  if (valuesArray.has_value()) {
    auto pads = mlir::ArrayAttr(*valuesArray);
    if (pads.size() != 4)
      return "pads must have 4 elements.";
    pad_x = cast<IntegerAttr>(pads[2]).getInt();
    pad_y = cast<IntegerAttr>(pads[3]).getInt();
  } else {
    // Default padding is 0 unless specified otherwise.
    // https://onnx.ai/onnx/operators/onnx__Conv.html
    pad_x = pad_y = 0;
  }
  return std::nullopt;
}

/// Splits an NCHW image tensor into per-pixel, per-channel-tile 1x1 slices:
/// tiles[t][x][y] holds channels [t*tileSize, (t+1)*tileSize) of pixel (x, y).
/// `tiles` must be pre-sized to [tileCount][input_w][input_h].
void tileImageTensorByChannel(Value imageTensor,
    SmallVector<SmallVector<SmallVector<Value>>>& tiles, size_t tileSize,
    ConversionPatternRewriter& rewriter) {
  ShapedType imageShape = mlir::cast<ShapedType>(imageTensor.getType());
  size_t input_h = getImageHeight(imageShape);
  size_t input_w = getImageWidth(imageShape);
  size_t tileCount = ceilIntegerDivide(getImageChannel(imageShape), tileSize);
  size_t tileRest = getImageChannel(imageShape) % tileSize;
  SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(tileSize), rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(1)};
  Location loc = imageTensor.getLoc();
  for (size_t i = 0; i < tileCount; i++) {
    // The final channel tile may cover only the remaining channels.
    if (i == tileCount - 1 && tileRest != 0)
      sizes[1] = rewriter.getIndexAttr(tileRest);
    for (size_t x = 0; x < input_w; x++) {
      for (size_t y = 0; y < input_h; y++) {
        offsets[1] = rewriter.getIndexAttr(i * tileSize);
        offsets[2] = rewriter.getIndexAttr(x);
        offsets[3] = rewriter.getIndexAttr(y);
        tiles[i][x][y] = tensor::ExtractSliceOp::create(
            rewriter, loc, imageTensor, offsets, sizes, strides);
      }
    }
  }
}

/// Flattens outputTiles[tile][x][y] into a SpatImgConcatOp in pixel-major
/// order (all channel tiles of one pixel, then the next pixel of the row,
/// then the next row).
Value createImgConcatOp(
    SmallVector<SmallVector<SmallVector<Value>>>& outputTiles,
    ConversionPatternRewriter& rewriter, Location& loc, Type outputType) {
  // Populate the outputTiles for the concat in the given order:
  // 1. Start top left pixel
  // 2. Continue on its right pixel till the end of the row
  // 3. Restart on the next row
  size_t outputTileCount = outputTiles.size();
  size_t output_w = outputTiles[0].size();
  size_t output_h = outputTiles[0][0].size();
  SmallVector<Value> tilesToConcat;
  tilesToConcat.reserve(output_h * output_w * outputTileCount * crossbarSize);
  // NOTE(review): outX is bounded by output_h but indexes the dimension of
  // size output_w (and vice versa for outY). This is only in-bounds when the
  // tile grid is square or the w/h names are intentionally swapped -- confirm
  // against tileImageTensorByChannel, where x indexes the width dimension.
  for (size_t outX = 0; outX < output_h; outX++)
    for (size_t outY = 0; outY < output_w; outY++)
      for (size_t outTile = 0; outTile < outputTileCount; outTile++)
        tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
  return spatial::SpatImgConcatOp::create(
      rewriter, loc, outputType, tilesToConcat);
}

/// Returns success iff (inX, inY) lies inside the [0, input_w) x [0, input_h)
/// image; out-of-image positions must fall within the padding region
/// (asserted), and yield failure.
LogicalResult verifyWithinBoundsAndPaddings(size_t input_w, size_t input_h,
    int inX, int inY, size_t pad_x, size_t pad_y) {
  if (inX < 0) {
    assert((size_t) (-inX) <= pad_x &&
           "verifyWithinBoundsAndPaddings: Negative x value out of padding");
    return failure();
  }
  if (inY < 0) {
    assert((size_t) (-inY) <= pad_y &&
           "verifyWithinBoundsAndPaddings: Negative y value out of padding");
    return failure();
  }
  if ((size_t) inX >= input_w || (size_t) inY >= input_h) {
    assert((size_t) inX < input_w + pad_x &&
           "verifyWithinBoundsAndPaddings: Positive x out of bounds");
    assert((size_t) inY < input_h + pad_y &&
           "verifyWithinBoundsAndPaddings: Positive y out of bounds");
    return failure();
  }
  return success();
}

/// Extracts the 1x1 pixel slice at (x, y) for channel tile `t` (crossbarSize
/// channels, or channelTileRest channels for the final partial tile) from an
/// NCHW tensor.
Value createExtractSliceImg(Value valToSlice, size_t x, size_t y, size_t t,
    size_t channelTileCount, size_t channelTileRest, size_t input_w,
    size_t input_h, PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(1)};
  if (t == channelTileCount - 1 && channelTileRest != 0)
    sizes[1] = rewriter.getIndexAttr(channelTileRest);
  offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
  offsets[2] = rewriter.getIndexAttr(x);
  offsets[3] = rewriter.getIndexAttr(y);
  return tensor::ExtractSliceOp::create(
      rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
}

/// Resolves the tile (x, y, t) of image value `v`, looking through the ops
/// that produced it: block arguments get an extract_slice; compute/receive
/// ops are returned as-is (they already denote a single tile); img-concat and
/// tensor.concat are traversed recursively to the input that owns the tile.
Value indexImgValue(Value v, size_t x, size_t y, size_t t,
    size_t channelTileCount, size_t channelTileRest, size_t input_w,
    size_t input_h, ConversionPatternRewriter& rewriter) {
  // Follow the conversion framework's mapping to the up-to-date value.
  auto newV = rewriter.getRemappedValue(v);
  if (newV)
    v = newV;
  if (!v.getDefiningOp())
    return createExtractSliceImg(v, x, y, t, channelTileCount, channelTileRest,
        input_w, input_h, rewriter);
  if (auto computeOp = v.getDefiningOp<spatial::SpatWeightedCompute>()) {
    // We found the computeOp that produces the tile we want, just return this
    // value.
    // TODO: Should we assert that x,y,t are zero?
    assert(x == 0 && y == 0 && t == 0 &&
           "indexImgValue: WeightedComputeOp tile indeces should be zero");
    return v;
  }
  // NOTE(review): op class name reconstructed from the mangled source --
  // confirm it matches the Spatial dialect's receive op.
  if (auto receiveOp = v.getDefiningOp<spatial::SpatReceiveOp>()) {
    // This is a receiveOp, just return its value which will be resolved later
    assert(x == 0 && y == 0 && t == 0 &&
           "indexImgValue: receiveOp tile indeces should be zero");
    return v;
  }
  if (auto imgConcatOp = v.getDefiningOp<spatial::SpatImgConcatOp>()) {
    auto imgConcatInput = imgConcatOp.getInputTile(x, y, t);
    // TODO: Is this correct?
    // Above we already index exactly the tile we want, so `x=y=t=0` in
    // recursive call
    return indexImgValue(imgConcatInput, 0, 0, 0, channelTileCount,
        channelTileRest, input_w, input_h, rewriter);
  }
  if (auto tensorConcatOp = v.getDefiningOp<tensor::ConcatOp>()) {
    // This can be recursive.
    // First, get the input tensors of the tensor.concatOp
    // Then, find the input tensor that contains the tile we want
    // Finally, recursive call asking for the tile
    auto concatAxis = tensorConcatOp.getDim();
    assert(concatAxis != 0 && "Expecting to concat on channel/x/y axis");
    assert(concatAxis == 1 &&
           "TODO: Make sure this works and makes sense for other axis.");
    SmallVector<int64_t> indexDims = {1,
        static_cast<int64_t>(t * crossbarSize), static_cast<int64_t>(x),
        static_cast<int64_t>(y)};
    // Find the input tensor that contains the tile we want
    int64_t currentTile = 0;
    for (auto concatInput : tensorConcatOp.getInputs()) {
      auto concatInputShape = cast<ShapedType>(concatInput.getType());
      assert(concatInputShape.getRank() == 4 && "Expecting an image tensor");
      auto concatInputSizeOnAxis = concatInputShape.getDimSize(concatAxis);
      if (currentTile + concatInputSizeOnAxis > indexDims[concatAxis]) {
        // This input tensor contains the tile we want
        indexDims[concatAxis] -= currentTile;
        if (indexDims[1] % crossbarSize != 0) {
          assert(ignoreConcatError &&
                 "TODO: Handle non-tile aligned tensor, or set "
                 "--ignore-concat-error=true");
        }
        return indexImgValue(concatInput, indexDims[2], indexDims[3],
            indexDims[1] / crossbarSize, channelTileCount, channelTileRest,
            input_w, input_h, rewriter);
      }
      currentTile += concatInputSizeOnAxis;
    }
    llvm_unreachable("Could not find the input tensor that contains the tile "
                     "within tensor.ConcatOp");
  }
  v.dump();
  // Fix: was assert(false), which let control fall off the end of this
  // Value-returning function in NDEBUG builds (undefined behavior).
  llvm_unreachable("indexImgValue: unsupported operation");
}

/// Populates inputTiles[t][x][y] by extract_slice-ing directly from
/// `wholeInputTensor` (the block-argument case: no producer op to look
/// through).
void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
    SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
    size_t channelTileCount, size_t channelTileRest, size_t input_w,
    size_t input_h, PatternRewriter& rewriter) {
  SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> offsets(4, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
      rewriter.getIndexAttr(1)};
  Location loc = wholeInputTensor.getLoc();
  for (size_t t = 0; t < channelTileCount; t++) {
    // The final channel tile may cover only the remaining channels.
    if (t == channelTileCount - 1 && channelTileRest != 0)
      sizes[1] = rewriter.getIndexAttr(channelTileRest);
    for (size_t x = 0; x < input_w; x++) {
      for (size_t y = 0; y < input_h; y++) {
        offsets[1] = rewriter.getIndexAttr(t * crossbarSize);
        offsets[2] = rewriter.getIndexAttr(x);
        offsets[3] = rewriter.getIndexAttr(y);
        inputTiles[t][x][y] = tensor::ExtractSliceOp::create(
            rewriter, loc, wholeInputTensor, offsets, sizes, strides);
      }
    }
  }
}

/// Populates inputTiles[t][x][y] via indexImgValue, i.e. looking through
/// producer ops where possible. Returns std::nullopt (reserved for error
/// reporting).
std::optional<std::string> resolveImgInputTiles(Value wholeInputTensor,
    SmallVector<SmallVector<SmallVector<Value>>>& inputTiles,
    size_t channelTileCount, size_t channelTileRest, size_t input_w,
    size_t input_h, ConversionPatternRewriter& rewriter) {
  for (size_t t = 0; t < channelTileCount; t++) {
    for (size_t x = 0; x < input_w; x++) {
      for (size_t y = 0; y < input_h; y++) {
        inputTiles[t][x][y] = indexImgValue(wholeInputTensor, x, y, t,
            channelTileCount, channelTileRest, input_w, input_h, rewriter);
      }
    }
  }
  return std::nullopt;
}

/// Handles a flatten-like reshape (4-D NCHW image -> 2-D vector): walks the
/// flattened index space, fetches each source pixel tile, reshapes it to a
/// 1 x C vector and buckets it per core into `inputTiles`.
LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
    const size_t inputTilesCount, const size_t lastInputTileDimension,
    TensorType inputShape, TensorType outputShape, Value reshapeInput,
    ConversionPatternRewriter& rewriter) {
  // Only support reshape between an image and a vector (i.e. flatten)
  if (inputShape.getRank() != 4 || outputShape.getRank() != 2) {
    return rewriter.notifyMatchFailure(reshapeInput.getDefiningOp(),
        "resolveVecInputTiles only supports reshapes from 4D to 2D tensors");
  }
  /*
   * From a 4D tensor to a 2D tensor
   */
  auto N = inputShape.getDimSize(0);
  auto C = inputShape.getDimSize(1);
  auto H = inputShape.getDimSize(2);
  auto W = inputShape.getDimSize(3);
  assert(N == 1 && "Only support N = 1 for image tensors");
  for (size_t i = 0; i < inputTilesCount; i++) {
    // Decompose the flat tile index i into (channel-tile, x, y) coordinates.
    auto c = (i / (H * W)) % C;
    // TODO: Is this correct? Or should I invert h and w?
    auto w = (i / H) % W;
    auto h = i % H;
    Value curTile = indexImgValue(reshapeInput, w, h, c, inputTilesCount,
        lastInputTileDimension, W, H, rewriter);
    // Assert the shape of the tile, and reshape it
    auto curTileShape = cast<ShapedType>(curTile.getType());
    assert(curTileShape.getRank() == 4 &&
           "We just reshaped an image tensor, why rank != 4?");
    assert(curTileShape.getDimSize(0) == 1 &&
           "We just reshaped an image tensor with N = 1, why is it now != 1?");
    assert(curTileShape.getDimSize(2) == 1 &&
           "We should have just looked up a single pixel why W != 1?");
    assert(curTileShape.getDimSize(3) == 1 &&
           "We should have just looked up a single pixel why H != 1?");
    // Reshape this pixel tensor into a vector, for compatibility with the
    // rest
    SmallVector<int64_t> newShapeVals = {
        curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
    auto shapeType =
        RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())},
            rewriter.getI64Type());
    Value shapeTensor = arith::ConstantOp::create(rewriter,
        reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType,
            newShapeVals));
    auto reshapedType =
        RankedTensorType::get(newShapeVals, curTileShape.getElementType());
    auto reshapedCurTile = tosa::ReshapeOp::create(rewriter,
        reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
    size_t coreIndex = i / crossbarCountInCore;
    inputTiles[coreIndex].push_back(reshapedCurTile);
  }
  return success();
}

/// Computes the [start, end) input range touched by the kernel at output
/// position `out_pos`, clamped to the valid (unpadded) part of the input.
/// `firstValid` is the first dilated tap that lands inside the input when the
/// window starts in the left/top padding.
std::pair<int64_t, int64_t> kernel_get_start_and_end(int64_t out_pos,
    int64_t input_width, int64_t krn_width, int64_t stride, int64_t dilation,
    int64_t pad) {
  int64_t firstValid =
      std::ceil(static_cast<double>(pad) / dilation) * dilation - pad;
  int64_t start = std::max(firstValid, out_pos * stride - pad);
  int64_t end = std::min(
      input_width, out_pos * stride + (krn_width - 1) * dilation + 1 - pad);
  assert(start >= 0 && "Start position must be non-negative.");
  assert(end >= 0 && "End position must be non-negative.");
  return std::make_pair(start, end);
}

/// Grows the second operand segment (the variadic inputs) of a
/// SpatWeightedCompute by `increment`, keeping the first segment unchanged.
void incrementWeightedComputeInputsSegmentSize(
    spatial::SpatWeightedCompute wcomputeOp, int increment) {
  auto oldSegmentSizes = wcomputeOp->getAttrOfType<DenseI32ArrayAttr>(
      wcomputeOp.getOperandSegmentSizesAttrName());
  auto newSegmentSizes = DenseI32ArrayAttr::get(wcomputeOp->getContext(),
      {oldSegmentSizes[0], oldSegmentSizes[1] + increment});
  wcomputeOp->setAttr(
      wcomputeOp.getOperandSegmentSizesAttrName(), newSegmentSizes);
}

/// Returns the result index of `v` among `op`'s results; asserts that `v` is
/// actually one of them.
int getResultIndex(Operation* op, Value v) {
  int resultNumber = -1;
  for (auto result : op->getResults()) {
    if (result == v) {
      resultNumber = result.getResultNumber();
      break;
    }
  }
  assert(resultNumber >= 0 && "Value not found in given operation's results.");
  return resultNumber;
}

} // namespace onnx_mlir