huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getInput(),
|
||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
|
||||
if (op->use_empty()) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
auto outputType = cast<ShapedType>(op.getResult().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received = pim::PimReceiveOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
|
||||
auto inputType = cast<RankedTensorType>(op.getInput().getType());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(op.getNumResults());
|
||||
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
|
||||
auto outputType = cast<RankedTensorType>(output.getType());
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(inputType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
replacements.push_back(
|
||||
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
|
||||
.getResult());
|
||||
}
|
||||
rewriter.replaceOp(op, replacements);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value concatenated =
|
||||
pim::PimConcatOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, concatenated);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<ChannelSendLowering,
|
||||
ChannelReceiveLowering,
|
||||
ChannelSendTensorLowering,
|
||||
ChannelReceiveTensorLowering,
|
||||
ExtractRowsLowering,
|
||||
ConcatLowering>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user