huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact.
|
||||
struct FoldConcatOfContiguousSlices : OpRewritePattern<tensor::ConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override {
|
||||
if (op.getDim() != 0)
|
||||
return failure();
|
||||
|
||||
Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc());
|
||||
if (!packed)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOp(op, packed);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
packedShape[0] *= count;
|
||||
return RankedTensorType::get(packedShape, elementType.getElementType());
|
||||
}
|
||||
|
||||
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
|
||||
if (values.empty())
|
||||
return {};
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
|
||||
auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!firstSliceOp)
|
||||
return {};
|
||||
|
||||
auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
|
||||
auto sourceType = dyn_cast<RankedTensorType>(firstSliceOp.getSource().getType());
|
||||
if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
|
||||
|| firstType.getRank() == 0)
|
||||
return {};
|
||||
|
||||
auto hasStaticValues = [](ArrayRef<int64_t> values) {
|
||||
return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
};
|
||||
if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(firstSliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
|
||||
int64_t rowsPerValue = firstSizes[0];
|
||||
if (ShapedType::isDynamic(rowsPerValue))
|
||||
return {};
|
||||
|
||||
for (size_t index = 1; index < values.size(); ++index) {
|
||||
auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
|
||||
if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource()
|
||||
|| sliceOp.getResult().getType() != firstSliceOp.getResult().getType()
|
||||
|| !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes())
|
||||
|| !hasStaticValues(sliceOp.getStaticStrides()))
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
|
||||
return {};
|
||||
|
||||
if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowsPerValue)
|
||||
return {};
|
||||
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim)
|
||||
if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim])
|
||||
return {};
|
||||
}
|
||||
|
||||
auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides;
|
||||
offsets.reserve(firstType.getRank());
|
||||
sizes.reserve(firstType.getRank());
|
||||
strides.reserve(firstType.getRank());
|
||||
|
||||
offsets.push_back(builder.getIndexAttr(firstOffsets[0]));
|
||||
sizes.push_back(builder.getIndexAttr(rowsPerValue * static_cast<int64_t>(values.size())));
|
||||
strides.push_back(builder.getIndexAttr(firstStrides[0]));
|
||||
for (int64_t dim = 1; dim < firstType.getRank(); ++dim) {
|
||||
offsets.push_back(builder.getIndexAttr(firstOffsets[dim]));
|
||||
sizes.push_back(builder.getIndexAttr(firstSizes[dim]));
|
||||
strides.push_back(builder.getIndexAttr(firstStrides[dim]));
|
||||
}
|
||||
|
||||
bool coversWholeSource = packedType == sourceType;
|
||||
for (int64_t dim = 0; coversWholeSource && dim < sourceType.getRank(); ++dim)
|
||||
coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1;
|
||||
if (coversWholeSource)
|
||||
return firstSliceOp.getSource();
|
||||
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<FoldConcatOfContiguousSlices>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user