Files
Raptor/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp
T
2026-06-24 15:52:07 +02:00

104 lines
4.2 KiB
C++

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Propagates debug metadata from `source` to `target`.
///
/// Every attribute of `source` whose name begins with "raptor." is set on
/// `target` under the same name (replacing any existing value there). All
/// other attributes of `source` are ignored.
static void copyRaptorDebugAttrs(Operation* source, Operation* target) {
  for (const NamedAttribute& namedAttr : source->getAttrs()) {
    if (!namedAttr.getName().strref().starts_with("raptor."))
      continue;
    target->setAttr(namedAttr.getName(), namedAttr.getValue());
  }
}
/// Lowers `spat.channel_send` to `pim.send`.
///
/// The transfer size in bytes is derived from the sent tensor. If that
/// computation fails, the pattern does not apply. On success the PIM send is
/// created with the same payload and target core id, "raptor." attributes are
/// carried over, and the spatial op is erased.
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
    auto payload = op.getInput();
    auto bytes = getTensorSizeInBytesAttr(rewriter, op.getOperation(), payload);
    if (succeeded(bytes)) {
      auto pimSend = pim::PimSendOp::create(rewriter, op.getLoc(), payload, *bytes, op.getTargetCoreId());
      copyRaptorDebugAttrs(op.getOperation(), pimSend.getOperation());
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
};
/// Lowers `spat.channel_receive` to `pim.receive`.
///
/// A receive whose result has no uses is simply erased. Otherwise:
/// - a `tensor.empty` with the result's shape and element type is created as
///   the destination buffer;
/// - that buffer is passed to `pim.receive` together with the transfer size in
///   bytes and the source core id;
/// - "raptor." attributes are copied to the new op, and its output replaces the
///   spatial op.
///
/// The pattern fails, without modifying the IR, if the transfer size cannot be
/// computed.
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
    // NOTE(review): dropping an unused receive presumably relies on the
    // matching send being handled elsewhere — confirm.
    if (op->use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }

    // Compute the size first: a pattern that returns failure() must not have
    // changed the IR. The destination buffer is therefore only created once
    // the rewrite is known to succeed (previously a dangling tensor.empty was
    // left behind on failure).
    auto sizeAttr = getTensorSizeInBytesAttr(rewriter, op.getOperation(), op.getResult());
    if (failed(sizeAttr))
      return failure();

    auto outputType = cast<ShapedType>(op.getResult().getType());
    Value outputBuffer =
        tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
    auto receive = pim::PimReceiveOp::create(
        rewriter, op.getLoc(), op.getResult().getType(), outputBuffer, *sizeAttr, op.getSourceCoreId());
    copyRaptorDebugAttrs(op.getOperation(), receive.getOperation());
    Value received = receive.getOutput();
    rewriter.replaceOp(op, received);
    return success();
  }
};
/// Lowers `spat.extract_rows` into one `tensor.extract_slice` per result.
///
/// Layout of the slices:
/// - The results are taken back to back along dimension 0 of the input.
///   Result i starts at the row right after the last row of result i-1 and
///   spans that result's own row count.
/// - Every other dimension of the input is taken in full, with unit strides.
///
/// If all results have the same row count, this matches the previous
/// `rowIndex * rows` offset. With a running offset, non-uniform splits (e.g. a
/// smaller trailing chunk) are also handled correctly, and inputs of any rank
/// >= 1 are supported.
///
/// The pattern fails without touching the IR in these cases:
/// - the input or any result is not statically shaped;
/// - a result's rank differs from the input's rank;
/// - the results together need more rows than the input has.
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
    auto inputType = cast<RankedTensorType>(op.getInput().getType());
    const int64_t rank = inputType.getRank();
    if (rank < 1 || !inputType.hasStaticShape())
      return failure();

    // Validate every result before creating any op, so that a failure leaves
    // the IR unchanged.
    int64_t totalRows = 0;
    for (Value output : op.getOutputs()) {
      auto outputType = cast<RankedTensorType>(output.getType());
      if (outputType.getRank() != rank || !outputType.hasStaticShape())
        return failure();
      totalRows += outputType.getDimSize(0);
    }
    if (totalRows > inputType.getDimSize(0))
      return failure();

    OpFoldResult zero = rewriter.getIndexAttr(0);
    OpFoldResult one = rewriter.getIndexAttr(1);
    SmallVector<OpFoldResult> strides(rank, one);

    SmallVector<Value> replacements;
    replacements.reserve(op.getNumResults());
    int64_t rowOffset = 0;
    for (Value output : op.getOutputs()) {
      auto outputType = cast<RankedTensorType>(output.getType());
      const int64_t numRows = outputType.getDimSize(0);

      SmallVector<OpFoldResult> offsets(rank, zero);
      offsets[0] = rewriter.getIndexAttr(rowOffset);

      SmallVector<OpFoldResult> sizes;
      sizes.reserve(rank);
      sizes.push_back(rewriter.getIndexAttr(numRows));
      for (int64_t dim = 1; dim < rank; ++dim)
        sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim)));

      replacements.push_back(
          tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
              .getResult());
      rowOffset += numRows;
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
/// Lowers `spat.concat` to `pim.concat`.
///
/// A `tensor.empty` with the result's shape and element type is created as
/// the destination. The concatenation axis and the inputs are forwarded
/// unchanged, and the PIM concat's output replaces the spatial op.
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
    Location loc = op.getLoc();
    auto resultType = op.getOutput().getType();
    auto resultShaped = cast<ShapedType>(resultType);
    auto destination = tensor::EmptyOp::create(rewriter, loc, resultShaped.getShape(), resultShaped.getElementType());
    auto pimConcat =
        pim::PimConcatOp::create(rewriter, loc, resultType, op.getAxisAttr(), op.getInputs(), destination.getResult());
    Value joined = pimConcat.getOutput();
    rewriter.replaceOp(op, joined);
    return success();
  }
};
} // namespace
/// Adds the Spatial-to-PIM lowering patterns defined in this file to
/// `patterns`: channel send, channel receive, extract_rows and concat.
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<ChannelSendLowering>(context);
  patterns.add<ChannelReceiveLowering>(context);
  patterns.add<ExtractRowsLowering>(context);
  patterns.add<ConcatLowering>(context);
}
} // namespace onnx_mlir