#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns true when no entry of `values` is the dynamic sentinel.
bool allStaticValues(ArrayRef<int64_t> values) {
  return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
}

/// Returns true when `sliceOp` has only static offsets, sizes and strides.
bool hasStaticSliceParameters(tensor::ExtractSliceOp sliceOp) {
  return allStaticValues(sliceOp.getStaticOffsets()) && allStaticValues(sliceOp.getStaticSizes())
         && allStaticValues(sliceOp.getStaticStrides());
}

/// Returns true when `candidate` is produced by a tensor.extract_slice that reads
/// the `position`-th chunk directly following `first` along dimension 0 of the
/// same source tensor, i.e. `first` and its successors can be merged into one
/// larger slice.
///
/// Both slices must share source, result type, sizes and strides, have only
/// static parameters, and use identical offsets in every dimension but 0.
/// Along dimension 0 a slice of size `s` and stride `t` touches `s` elements
/// spaced `t` apart, so the next contiguous chunk starts `s * t` elements later
/// (the stride must be part of the formula, otherwise strided slices that
/// overlap would be merged).
bool isNextPackedSliceChunk(tensor::ExtractSliceOp first, Value candidate, int64_t position) {
  auto sliceOp = candidate.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp || sliceOp.getSource() != first.getSource()
      || sliceOp.getResult().getType() != first.getResult().getType())
    return false;
  if (!hasStaticSliceParameters(first) || !hasStaticSliceParameters(sliceOp))
    return false;

  ArrayRef<int64_t> firstOffsets = first.getStaticOffsets();
  ArrayRef<int64_t> firstSizes = first.getStaticSizes();
  ArrayRef<int64_t> firstStrides = first.getStaticStrides();
  if (firstOffsets.empty() || sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
    return false;

  ArrayRef<int64_t> offsets = sliceOp.getStaticOffsets();
  if (offsets[0] != firstOffsets[0] + position * firstSizes[0] * firstStrides[0])
    return false;
  for (size_t dim = 1; dim < offsets.size(); ++dim)
    if (offsets[dim] != firstOffsets[dim])
      return false;
  return true;
}

/// Creates a tensor.extract_slice of `source` with the given static offsets,
/// sizes and strides (one entry per source dimension) and returns its result.
Value createStaticExtractSlice(OpBuilder& builder,
    Location loc,
    RankedTensorType resultType,
    Value source,
    ArrayRef<int64_t> offsets,
    ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  auto toIndexAttrs = [&builder](ArrayRef<int64_t> values) {
    SmallVector<OpFoldResult> attrs;
    attrs.reserve(values.size());
    for (int64_t value : values)
      attrs.push_back(builder.getIndexAttr(value));
    return attrs;
  };
  return tensor::ExtractSliceOp::create(
      builder, loc, resultType, source, toIndexAttrs(offsets), toIndexAttrs(sizes), toIndexAttrs(strides))
      .getResult();
}

/// Rewrites an axis-0 spat.concat into a pim.concat whose operands are packed:
///  - a run of two or more contiguous tensor.extract_slice operands over the
///    same source is replaced by a single slice (or by the source itself when
///    the run covers it entirely);
///  - a run of consecutive results of one spat.extract_rows op (including a
///    single result) is replaced by one slice of that op's input.
/// Other operands are forwarded unchanged. The pattern fails when nothing was
/// packed, leaving the spat.concat in place.
struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatConcatOp concatOp, PatternRewriter& rewriter) const override {
    auto inputs = concatOp.getInputs();
    if (concatOp.getAxis() != 0 || inputs.empty())
      return failure();

    Location loc = concatOp.getLoc();
    const unsigned numInputs = inputs.size();
    SmallVector<Value> packedInputs;
    packedInputs.reserve(numInputs);
    bool changed = false;

    for (unsigned index = 0; index < numInputs;) {
      Value input = inputs[index];

      // Extend the run only while the next operand is a contiguous chunk of the
      // same source, so an unrelated slice does not spoil a packable prefix.
      if (auto firstSlice = input.getDefiningOp<tensor::ExtractSliceOp>()) {
        unsigned endIndex = index + 1;
        while (endIndex < numInputs && isNextPackedSliceChunk(firstSlice, inputs[endIndex], endIndex - index))
          ++endIndex;
        // A one-element run would be "packed" into itself, which is no change.
        if (endIndex - index > 1) {
          if (Value packedInput =
                  createPackedExtractSliceTensor(inputs.slice(index, endIndex - index), rewriter, loc)) {
            packedInputs.push_back(packedInput);
            changed = true;
            index = endIndex;
            continue;
          }
        }
      }

      auto result = dyn_cast<OpResult>(input);
      if (!result) {
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      // Collect consecutive results (by result number) of the same producer.
      Operation* owner = result.getOwner();
      unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < numInputs) {
        auto nextResult = dyn_cast<OpResult>(inputs[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      unsigned count = endIndex - index;
      Value packedInput;
      if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, loc);

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      }
      else {
        for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
          packedInputs.push_back(inputs[oldIndex]);
      }
      index = endIndex;
    }

    if (!changed)
      return failure();

    auto outputType = cast<RankedTensorType>(concatOp.getOutput().getType());
    Value emptyOutput =
        tensor::EmptyOp::create(rewriter, loc, outputType.getShape(), outputType.getElementType()).getResult();
    auto newConcat = pim::PimConcatOp::create(
        rewriter, loc, concatOp.getOutput().getType(), concatOp.getAxisAttr(), ValueRange(packedInputs), emptyOutput);
    rewriter.replaceOp(concatOp, newConcat.getOutput());
    return success();
  }
};

} // namespace

/// Returns `elementType` with its leading dimension multiplied by `count`.
/// `elementType` must have rank >= 1.
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  assert(elementType.getRank() > 0 && "packing requires at least one dimension");
  SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
  packedShape[0] *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}

/// Returns the `index`-th `chunkType`-shaped chunk along dimension 0 of
/// `packedValue`. When the packed value already has exactly `chunkType` and
/// `index` is 0 it is returned directly; otherwise a unit-stride
/// tensor.extract_slice is created with `builder`.
Value extractPackedChunk(
    Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
  auto packedType = dyn_cast<RankedTensorType>(packedValue.getType());
  if (packedType && packedType == chunkType && index == 0)
    return packedValue;

  assert(chunkType.getRank() > 0 && "chunks are taken along dimension 0");
  const size_t rank = static_cast<size_t>(chunkType.getRank());
  SmallVector<int64_t> offsets(rank, 0);
  SmallVector<int64_t> sizes(chunkType.getShape().begin(), chunkType.getShape().end());
  SmallVector<int64_t> strides(rank, 1);
  offsets[0] = static_cast<int64_t>(index) * chunkType.getDimSize(0);
  return createStaticExtractSlice(builder, loc, chunkType, packedValue, offsets, sizes, strides);
}

/// Replaces results [startIndex, startIndex + count) of `extractRowsOp` with a
/// single dim-0 slice of its input, or with the input itself when the range
/// covers it entirely. Returns a null Value when the shapes do not allow it.
///
/// The slice offset is computed as `startIndex * rowsPerValue`, so every result
/// up to the end of the range must share the same static row type. The rows
/// must also match the input in rank, element type and trailing dimensions, and
/// the range must lie within the input.
Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp,
    unsigned startIndex,
    unsigned count,
    OpBuilder& builder,
    Location loc) {
  auto outputs = extractRowsOp.getOutputs();
  if (count == 0 || static_cast<size_t>(startIndex) + count > outputs.size())
    return {};

  auto rowType = dyn_cast<RankedTensorType>(outputs[startIndex].getType());
  auto inputType = dyn_cast<RankedTensorType>(extractRowsOp.getInput().getType());
  if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0)
    return {};
  // The packed slice keeps every trailing dimension of the input unchanged.
  if (rowType.getRank() != inputType.getRank() || rowType.getElementType() != inputType.getElementType()
      || rowType.getShape().drop_front() != inputType.getShape().drop_front())
    return {};
  for (unsigned resultIndex = 0; resultIndex < startIndex + count; ++resultIndex)
    if (outputs[resultIndex].getType() != rowType)
      return {};

  const int64_t rowsPerValue = rowType.getDimSize(0);
  const int64_t firstRow = static_cast<int64_t>(startIndex) * rowsPerValue;
  const int64_t numRows = static_cast<int64_t>(count) * rowsPerValue;
  if (firstRow + numRows > inputType.getDimSize(0))
    return {};

  RankedTensorType packedType = getPackedTensorType(rowType, static_cast<int64_t>(count));
  if (startIndex == 0 && packedType == inputType)
    return extractRowsOp.getInput();

  const size_t rank = static_cast<size_t>(inputType.getRank());
  SmallVector<int64_t> offsets(rank, 0);
  SmallVector<int64_t> sizes(inputType.getShape().begin(), inputType.getShape().end());
  SmallVector<int64_t> strides(rank, 1);
  offsets[0] = firstRow;
  sizes[0] = numRows;
  return createStaticExtractSlice(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides);
}

/// Merges `values` (all tensor.extract_slice results) into one extract_slice
/// covering their concatenation along dimension 0.
///
/// Returns null for an empty range, `values.front()` for a single value, the
/// common source when the merged slice covers it entirely, and otherwise a new
/// slice created with `builder`. Fails (null) unless every value is a
/// contiguous successor of the first (see isNextPackedSliceChunk), all shapes
/// and slice parameters are static, and the first slice is not rank-reducing
/// (a rank-reduced chunk cannot be re-expressed as a larger slice of the
/// source by growing dimension 0).
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
  if (values.empty())
    return {};
  if (values.size() == 1)
    return values.front();

  auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
  if (!firstSliceOp)
    return {};

  auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
  auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
  if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
      || firstType.getRank() == 0 || firstType.getRank() != sourceType.getRank())
    return {};
  if (!hasStaticSliceParameters(firstSliceOp))
    return {};

  for (size_t index = 1; index < values.size(); ++index)
    if (!isNextPackedSliceChunk(firstSliceOp, values[index], static_cast<int64_t>(index)))
      return {};

  const int64_t numValues = static_cast<int64_t>(values.size());
  ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
  ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
  SmallVector<int64_t> sizes(firstSliceOp.getStaticSizes().begin(), firstSliceOp.getStaticSizes().end());
  sizes[0] *= numValues;
  RankedTensorType packedType = getPackedTensorType(firstType, numValues);

  bool coversWholeSource = packedType == sourceType
                           && llvm::all_of(firstOffsets, [](int64_t offset) { return offset == 0; })
                           && llvm::all_of(firstStrides, [](int64_t stride) { return stride == 1; });
  if (coversWholeSource)
    return firstSliceOp.getSource();

  return createStaticExtractSlice(
      builder, loc, packedType, firstSliceOp.getSource(), firstOffsets, sizes, firstStrides);
}

/// Registers the spat.concat input-packing pattern.
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
  patterns.add<PackSpatialConcatInputsPattern>(patterns.getContext());
}

/// Erases every unused tensor.concat, tensor.extract_slice and
/// spat.extract_rows op in `funcOp` (not only ones left over by packing).
/// Kinds are processed in that order. Within a kind, ops are visited in
/// reverse walk order, so that (within a block) a dead user is erased before
/// its same-kind producer is checked.
void eraseUnusedTensorPackingOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  auto eraseUnusedOps = [&](auto tag) {
    using OpTy = decltype(tag);
    SmallVector<OpTy> ops;
    funcOp.walk([&](OpTy op) { ops.push_back(op); });
    for (OpTy op : llvm::reverse(ops))
      if (op->use_empty())
        rewriter.eraseOp(op);
  };

  eraseUnusedOps(tensor::ConcatOp {});
  eraseUnusedOps(tensor::ExtractSliceOp {});
  eraseUnusedOps(spatial::SpatExtractRowsOp {});
}

} // namespace onnx_mlir