#include "mlir/Transforms/DialectConversion.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor

/// Wraps each integer in `values` into an index attribute, preserving order.
///
/// The result is suitable for the static offsets/sizes/strides operands of the
/// tensor slice builders used below.
static SmallVector<OpFoldResult> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult> attrs;
  attrs.reserve(values.size());
  for (int64_t value : values)
    attrs.push_back(builder.getIndexAttr(value));
  return attrs;
}

/// Returns `rank` copies of the index attribute `1`, i.e. a unit stride for
/// every dimension.
static SmallVector<OpFoldResult> getUnitStrides(Builder& builder, int64_t rank) {
  OpFoldResult unitStride = builder.getIndexAttr(1);
  return SmallVector<OpFoldResult>(static_cast<size_t>(rank), unitStride);
}

/// Lowers a `spatial::SpatReconciliatorOp` whose mode is "fragment_assembly"
/// into a `tensor.empty` followed by one `tensor.insert_slice` per fragment
/// (preceded by a `tensor.extract_slice` when only part of the source operand
/// is used).
///
/// Ops with any other mode (or no mode) are not matched. Unsupported inputs
/// (non-static types, missing metadata, non-unit strides) emit an op error and
/// fail the pattern.
struct LowerFragmentAssemblyReconciliatorPattern : OpConversionPattern<spatial::SpatReconciliatorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
                                OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    std::optional<StringRef> modeAttr = op.getMode();
    if (!modeAttr || *modeAttr != "fragment_assembly")
      return failure();

    auto resultType = dyn_cast<RankedTensorType>(op.getOutput().getType());
    if (!resultType || !resultType.hasStaticShape())
      return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");

    std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
    std::optional<ArrayRef<int64_t>> sourceOffsetsAttr = op.getFragmentSourceOffsets();
    std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
    if (!operandIndicesAttr || !sourceOffsetsAttr || !fragmentStridesAttr)
      return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");

    ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
    ArrayRef<int64_t> sourceOffsets = *sourceOffsetsAttr;
    ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
    ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
    ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
    int64_t rank = resultType.getRank();

    // Operand index 0 refers to the primary input; the extra fragments follow.
    SmallVector<Value> fragmentOperands {adaptor.getInput()};
    llvm::append_range(fragmentOperands, adaptor.getFragments());

    // NOTE(review): the indexing below (operandIndices into fragmentOperands,
    // `fragmentIndex * rank + dim` into the flat arrays) presumably relies on
    // this helper for bounds checking; it is defined elsewhere - confirm.
    if (failed(validateFragmentAssemblyMetadata(
            op, rank, fragmentOperands.size(), operandIndices, sourceOffsets, flatOffsets, flatSizes, flatStrides)))
      return failure();

    // Everything needed to emit the IR for one fragment.
    struct FragmentPlan {
      Value source;
      SmallVector<int64_t> destOffsets;
      SmallVector<int64_t> shape;
      // Set only when a slice of `source` must be extracted first.
      std::optional<SmallVector<int64_t>> sliceOffsets;
    };

    // Phase 1: validate and plan every fragment WITHOUT touching the IR, so a
    // failure never leaves partially created ops behind for the driver to
    // clean up. Checks run in the same order as the IR is later emitted, so
    // the first diagnostic reported is unchanged.
    SmallVector<FragmentPlan> plans;
    plans.reserve(operandIndices.size());
    const int64_t numFragments = static_cast<int64_t>(operandIndices.size());
    for (int64_t fragmentIndex = 0; fragmentIndex < numFragments; ++fragmentIndex) {
      const int64_t flatBase = fragmentIndex * rank;
      FragmentPlan plan;

      plan.destOffsets.reserve(rank);
      for (int64_t dim = 0; dim < rank; ++dim) {
        if (flatStrides[flatBase + dim] != 1)
          return op.emitOpError("fragment assembly lowering only supports unit strides");
        plan.destOffsets.push_back(flatOffsets[flatBase + dim]);
      }

      plan.source = fragmentOperands[operandIndices[fragmentIndex]];
      auto sourceType = dyn_cast<RankedTensorType>(plan.source.getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");

      plan.shape.reserve(rank);
      for (int64_t dim = 0; dim < rank; ++dim)
        plan.shape.push_back(flatSizes[flatBase + dim]);

      // The operand can be inserted as-is only when it is exactly the
      // fragment; otherwise the fragment is a slice of it.
      if (llvm::to_vector(sourceType.getShape()) != plan.shape || sourceOffsets[fragmentIndex] != 0) {
        FailureOr<SmallVector<int64_t>> sliceOffsets = getStaticSliceOffsetsForElementOffset(
            op, sourceType, plan.shape, sourceOffsets[fragmentIndex], "fragment assembly source slice");
        if (failed(sliceOffsets))
          return failure();
        plan.sliceOffsets = std::move(*sliceOffsets);
      }

      plans.push_back(std::move(plan));
    }

    // Phase 2: emit the IR; nothing below can fail.
    Value currentOutput =
        tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType())
            .getResult();
    const SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, rank);

    for (const FragmentPlan& plan : plans) {
      const SmallVector<OpFoldResult> shapeAttrs = getStaticIndexAttrs(rewriter, plan.shape);

      Value fragment = plan.source;
      if (plan.sliceOffsets) {
        fragment = tensor::ExtractSliceOp::create(rewriter,
                                                  op.getLoc(),
                                                  plan.source,
                                                  getStaticIndexAttrs(rewriter, *plan.sliceOffsets),
                                                  shapeAttrs,
                                                  unitStrides);
      }

      // Each insert consumes the previous result, threading the assembled
      // tensor through the chain.
      currentOutput = tensor::InsertSliceOp::create(rewriter,
                                                    op.getLoc(),
                                                    fragment,
                                                    currentOutput,
                                                    getStaticIndexAttrs(rewriter, plan.destOffsets),
                                                    shapeAttrs,
                                                    unitStrides)
                          .getResult();
    }

    rewriter.replaceOp(op, currentOutput);
    return success();
  }
};

/// Registers the generated patterns and the transpose lowering patterns.
void populateInitialPatterns(RewritePatternSet& patterns) {
  raptor::populateWithGenerated(patterns);
  populateTransposeLoweringPatterns(patterns);
}

/// Registers everything `populateInitialPatterns` does, plus the fragment
/// assembly reconciliator lowering.
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
  populateInitialPatterns(patterns);
  patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
}

} // namespace onnx_mlir