#include "mlir/Dialect/Tensor/IR/Tensor.h"

#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Translates a Spatial core id into the core id used by the PIM dialect.
/// The mapping is currently the identity; all lowerings below go through this
/// helper so the translation lives in a single place.
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }

/// Rewrites `spatial::SpatChannelSendOp` into `pim::PimSendOp`.
///
/// The new op receives the original input, a size attribute produced by
/// `getTensorSizeInBytesAttr` for that input, and the target core id as an
/// i32 attribute. The original op is erased afterwards.
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
    pim::PimSendOp::create(rewriter,
        op.getLoc(),
        op.getInput(),
        getTensorSizeInBytesAttr(rewriter, op.getInput()),
        rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
    rewriter.eraseOp(op);
    return success();
  }
};

/// Rewrites `spatial::SpatChannelReceiveOp` into `pim::PimReceiveOp`.
///
/// A receive whose result has no uses is erased without emitting anything.
/// Otherwise a `tensor.empty` with the result's shape and element type is
/// created and handed to the PIM receive op as its output buffer.
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
    // Nothing consumes the received value: drop the op instead of lowering it.
    if (op->use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    auto outputType = cast<ShapedType>(op.getResult().getType());
    Value outputBuffer =
        tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType())
            .getResult();
    Value received = pim::PimReceiveOp::create(rewriter,
        op.getLoc(),
        op.getResult().getType(),
        outputBuffer,
        getTensorSizeInBytesAttr(rewriter, op.getResult()),
        rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
                         .getOutput();
    rewriter.replaceOp(op, received);
    return success();
  }
};

/// Rewrites `spatial::SpatChannelSendTensorOp` into `pim::PimSendTensorOp`.
///
/// Every target core id is mapped through `toPimCoreId` and the resulting list
/// is attached as a dense i32 array attribute.
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
    SmallVector<int32_t> targetCoreIds;
    targetCoreIds.reserve(op.getTargetCoreIds().size());
    for (int32_t targetCoreId : op.getTargetCoreIds())
      targetCoreIds.push_back(toPimCoreId(targetCoreId));

    pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
    rewriter.eraseOp(op);
    return success();
  }
};

/// Rewrites `spatial::SpatChannelReceiveTensorOp` into
/// `pim::PimReceiveTensorOp`.
///
/// Source core ids are mapped through `toPimCoreId`; a `tensor.empty` matching
/// the output's shape and element type is passed as the output buffer.
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
    SmallVector<int32_t> sourceCoreIds;
    sourceCoreIds.reserve(op.getSourceCoreIds().size());
    for (int32_t sourceCoreId : op.getSourceCoreIds())
      sourceCoreIds.push_back(toPimCoreId(sourceCoreId));

    auto outputType = cast<ShapedType>(op.getOutput().getType());
    Value outputBuffer =
        tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType())
            .getResult();
    Value received = pim::PimReceiveTensorOp::create(rewriter,
        op.getLoc(),
        op.getOutput().getType(),
        outputBuffer,
        rewriter.getDenseI32ArrayAttr(sourceCoreIds))
                         .getOutput();
    rewriter.replaceOp(op, received);
    return success();
  }
};

/// Rewrites `spatial::SpatExtractRowsOp` into one `tensor.extract_slice` per
/// result.
///
/// Result `i` is a unit-stride slice of the input starting at row
/// `i * rows(result i)`, spanning `rows(result i)` rows and the input's full
/// dimension 1.
///
/// NOTE(review): the offset uses each result's own dimension 0, so contiguous,
/// non-overlapping slices presumably rely on all results having the same row
/// count — confirm against the op's verifier. The slice is built with exactly
/// two offsets/sizes/strides, i.e. it assumes rank-2 operands with static
/// dimensions — TODO confirm.
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
    auto inputType = cast<ShapedType>(op.getInput().getType());

    SmallVector<Value> replacements;
    replacements.reserve(op.getNumResults());
    for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
      // The extract_slice builder requires a ranked tensor result type.
      auto outputType = cast<RankedTensorType>(output.getType());
      SmallVector<OpFoldResult> offsets = {
          rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)),
          rewriter.getIndexAttr(0)};
      SmallVector<OpFoldResult> sizes = {
          rewriter.getIndexAttr(outputType.getDimSize(0)), rewriter.getIndexAttr(inputType.getDimSize(1))};
      SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
      replacements.push_back(
          tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
              .getResult());
    }

    rewriter.replaceOp(op, replacements);
    return success();
  }
};

/// Rewrites `spatial::SpatConcatOp` into `pim::PimConcatOp`.
///
/// The axis attribute and the inputs are forwarded unchanged; a `tensor.empty`
/// matching the output's shape and element type is passed as the output
/// buffer.
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
    auto outputType = cast<ShapedType>(op.getOutput().getType());
    Value outputBuffer =
        tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType())
            .getResult();
    Value concatenated = pim::PimConcatOp::create(
        rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
                             .getOutput();
    rewriter.replaceOp(op, concatenated);
    return success();
  }
};

} // namespace

/// Registers every Spatial channel/row/concat lowering pattern defined in this
/// file with `patterns`, using the pattern set's own MLIR context.
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
  patterns.add<ChannelSendLowering,
      ChannelReceiveLowering,
      ChannelSendTensorLowering,
      ChannelReceiveTensorLowering,
      ExtractRowsLowering,
      ConcatLowering>(patterns.getContext());
}

} // namespace onnx_mlir