254 lines
9.6 KiB
C++
254 lines
9.6 KiB
C++
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

#include <cassert>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Rewrites an axis-0 `spatial::SpatConcatOp` into a `pim::PimConcatOp` whose
/// operands have been packed:
///  * a run of consecutive `tensor.extract_slice` inputs reading the same source is
///    replaced by the single slice built by `createPackedExtractSliceTensor`;
///  * a run of consecutive, in-order results of one `spatial::SpatExtractRowsOp` is
///    replaced by the value built by `createPackedExtractRowsSlice`.
/// Inputs that cannot be packed are forwarded unchanged.
///
/// The pattern fails without touching the IR when the concat is not along axis 0,
/// has no inputs, or when no input was actually rewritten.
struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(spatial::SpatConcatOp concatOp, PatternRewriter& rewriter) const override {
    auto inputs = concatOp.getInputs();
    if (concatOp.getAxis() != 0 || inputs.empty())
      return failure();

    Location loc = concatOp.getLoc();
    const unsigned numInputs = inputs.size();
    SmallVector<Value> packedInputs;
    packedInputs.reserve(numInputs);
    bool changed = false;

    for (unsigned index = 0; index < numInputs;) {
      Value input = inputs[index];

      // Case 1: a run of extract_slice ops over one common source. Grouping by source
      // lets a same-source prefix pack even when a different-source slice follows it.
      if (auto sliceOp = input.getDefiningOp<tensor::ExtractSliceOp>()) {
        unsigned endIndex = index + 1;
        while (endIndex < numInputs) {
          auto nextSlice = inputs[endIndex].getDefiningOp<tensor::ExtractSliceOp>();
          if (!nextSlice || nextSlice.getSource() != sliceOp.getSource())
            break;
          ++endIndex;
        }

        if (Value packedInput =
                createPackedExtractSliceTensor(inputs.slice(index, endIndex - index), rewriter, loc)) {
          packedInputs.push_back(packedInput);
          // A single-element run comes back as the input itself; that is not a rewrite.
          changed |= packedInput != input;
          index = endIndex;
          continue;
        }
      }

      // Case 2: consecutive results (result numbers increasing by one) of one producer.
      auto result = dyn_cast<OpResult>(input);
      if (!result) {
        packedInputs.push_back(input);
        ++index;
        continue;
      }

      Operation* owner = result.getOwner();
      const unsigned startIndex = result.getResultNumber();
      unsigned endIndex = index + 1;
      while (endIndex < numInputs) {
        auto nextResult = dyn_cast<OpResult>(inputs[endIndex]);
        if (!nextResult || nextResult.getOwner() != owner
            || nextResult.getResultNumber() != startIndex + (endIndex - index))
          break;
        ++endIndex;
      }

      const unsigned count = endIndex - index;
      Value packedInput;
      if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
        packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, loc);

      if (packedInput) {
        packedInputs.push_back(packedInput);
        changed = true;
      } else {
        llvm::append_range(packedInputs, inputs.slice(index, count));
      }
      index = endIndex;
    }

    if (!changed)
      return failure();

    auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
    Value init =
        tensor::EmptyOp::create(rewriter, loc, outputType.getShape(), outputType.getElementType()).getResult();
    auto newConcat = pim::PimConcatOp::create(
        rewriter, loc, concatOp.getOutput().getType(), concatOp.getAxisAttr(), ValueRange(packedInputs), init);
    rewriter.replaceOp(concatOp, newConcat.getOutput());
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Returns the type obtained by stacking `count` tensors of type `elementType`
/// along dimension 0: same element type and trailing dimensions, with the
/// leading dimension multiplied by `count`. A dynamic leading dimension stays
/// dynamic. The encoding of `elementType`, if any, is not carried over.
///
/// Precondition: `elementType` has rank >= 1.
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
  assert(elementType.getRank() > 0 && "cannot pack a rank-0 tensor along dimension 0");

  SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
  // Multiplying the dynamic-size sentinel would silently produce a bogus static size.
  if (!ShapedType::isDynamic(packedShape[0]))
    packedShape[0] *= count;
  return RankedTensorType::get(packedShape, elementType.getElementType());
}
|
|
|
|
/// Extracts chunk number `index` from `packedValue`. The packed value is viewed
/// as consecutive `chunkType`-shaped pieces laid out along dimension 0.
///
/// If chunk 0 is requested and `packedValue` already has exactly `chunkType`, it is
/// returned unchanged. Otherwise a unit-stride `tensor.extract_slice` is created at
/// `loc`: it starts at row `index * chunkType.getDimSize(0)` and takes every other
/// dimension whole.
Value extractPackedChunk(
    Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) {
  if (index == 0) {
    auto currentType = dyn_cast<RankedTensorType>(packedValue.getType());
    if (currentType && currentType == chunkType)
      return packedValue;
  }

  const int64_t rank = chunkType.getRank();
  const int64_t chunkRows = chunkType.getDimSize(0);
  const int64_t firstRow = static_cast<int64_t>(index) * chunkRows;

  SmallVector<OpFoldResult> sliceOffsets;
  SmallVector<OpFoldResult> sliceSizes;
  SmallVector<OpFoldResult> sliceStrides;
  sliceOffsets.reserve(rank);
  sliceSizes.reserve(rank);
  sliceStrides.reserve(rank);

  // Leading dimension: select this chunk's rows.
  sliceOffsets.push_back(builder.getIndexAttr(firstRow));
  sliceSizes.push_back(builder.getIndexAttr(chunkRows));
  sliceStrides.push_back(builder.getIndexAttr(1));

  // Remaining dimensions: take them in full.
  int64_t dim = 1;
  while (dim < rank) {
    sliceOffsets.push_back(builder.getIndexAttr(0));
    sliceSizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim)));
    sliceStrides.push_back(builder.getIndexAttr(1));
    ++dim;
  }

  auto chunkSlice =
      tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, sliceOffsets, sliceSizes, sliceStrides);
  return chunkSlice.getResult();
}
|
|
|
|
/// Builds one value equal to outputs [startIndex, startIndex + count) of
/// `extractRowsOp` stacked along dimension 0, by slicing the op's input directly.
///
/// Assumes the op's outputs are consecutive, in-order row blocks of its input
/// along dimension 0 (NOTE(review): confirm against the SpatExtractRowsOp
/// definition). Under that assumption the start row is the sum of the row counts
/// of the preceding outputs, which stays correct even when output heights differ.
///
/// Returns:
///  * a null Value, with no IR created, when the range is empty or out of bounds,
///    or when any output up to the end of the range is not a static ranked tensor
///    with the input's element type, rank and trailing dimensions;
///  * the input itself when the range covers all of its rows;
///  * otherwise a new unit-stride `tensor.extract_slice` of the input at `loc`.
Value createPackedExtractRowsSlice(
    spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) {
  auto outputs = extractRowsOp.getOutputs();
  Value source = extractRowsOp.getInput();
  auto inputType = dyn_cast<RankedTensorType>(source.getType());
  if (count == 0 || static_cast<size_t>(startIndex) + count > outputs.size() || !inputType
      || !inputType.hasStaticShape() || inputType.getRank() == 0)
    return {};

  ArrayRef<int64_t> inputShape = inputType.getShape();
  const unsigned endIndex = startIndex + count;
  int64_t rowOffset = 0;  // rows produced by outputs before `startIndex`
  int64_t packedRows = 0; // rows produced by outputs in the requested range
  for (unsigned outputIndex = 0; outputIndex < endIndex; ++outputIndex) {
    auto rowType = dyn_cast<RankedTensorType>(outputs[outputIndex].getType());
    // Every output must be a static block whose non-leading dims match the input,
    // otherwise neither the offset nor the slice shape can be derived safely.
    if (!rowType || !rowType.hasStaticShape() || rowType.getElementType() != inputType.getElementType()
        || rowType.getRank() != inputType.getRank() || rowType.getShape().drop_front() != inputShape.drop_front())
      return {};
    if (outputIndex < startIndex)
      rowOffset += rowType.getDimSize(0);
    else
      packedRows += rowType.getDimSize(0);
  }

  // Never slice past the end of the input.
  if (rowOffset + packedRows > inputShape[0])
    return {};

  if (rowOffset == 0 && packedRows == inputShape[0])
    return source;

  SmallVector<int64_t> packedShape(inputShape.begin(), inputShape.end());
  packedShape[0] = packedRows;
  auto packedType = RankedTensorType::get(packedShape, inputType.getElementType());

  const int64_t rank = inputType.getRank();
  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    offsets.push_back(builder.getIndexAttr(dim == 0 ? rowOffset : 0));
    sizes.push_back(builder.getIndexAttr(packedShape[dim]));
    strides.push_back(builder.getIndexAttr(1));
  }

  return tensor::ExtractSliceOp::create(builder, loc, packedType, source, offsets, sizes, strides).getResult();
}
|
|
|
|
/// Tries to merge `values` (results of `tensor.extract_slice` ops) into one
/// `tensor.extract_slice` of their common source. The merged slice yields the same
/// elements as concatenating `values` along dimension 0.
///
/// Merging requires that:
///  * every value comes from a non-rank-reducing extract_slice of the same source;
///  * all slices have the same static result type, sizes and strides;
///  * all offsets are static;
///  * slice k starts exactly where slice k-1 ends along dimension 0, i.e. at
///    `offset0 + k * size0 * stride0`;
///  * all slices share their offsets on every other dimension.
///
/// Returns:
///  * a null Value, with no IR created, for an empty range or when the slices
///    cannot be merged;
///  * `values.front()` unchanged for a single value;
///  * the source itself when the merged slice would cover it entirely;
///  * otherwise a new `tensor.extract_slice` created at `loc`.
Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) {
  if (values.empty())
    return {};
  if (values.size() == 1)
    return values.front();

  auto firstSliceOp = values.front().getDefiningOp<tensor::ExtractSliceOp>();
  if (!firstSliceOp)
    return {};

  Value source = firstSliceOp.getSource();
  auto firstType = dyn_cast<RankedTensorType>(firstSliceOp.getResult().getType());
  auto sourceType = dyn_cast<RankedTensorType>(source.getType());
  // Rank-reducing slices are rejected. Their offset/size/stride lists are indexed by
  // source dimension, so they would not line up with the result dimensions used below.
  if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape()
      || firstType.getRank() == 0 || firstType.getRank() != sourceType.getRank())
    return {};

  auto hasStaticValues = [](ArrayRef<int64_t> entries) {
    return llvm::all_of(entries, [](int64_t value) { return !ShapedType::isDynamic(value); });
  };
  auto isFullyStatic = [&](tensor::ExtractSliceOp op) {
    return hasStaticValues(op.getStaticOffsets()) && hasStaticValues(op.getStaticSizes())
        && hasStaticValues(op.getStaticStrides());
  };
  if (!isFullyStatic(firstSliceOp))
    return {};

  ArrayRef<int64_t> firstOffsets = firstSliceOp.getStaticOffsets();
  ArrayRef<int64_t> firstSizes = firstSliceOp.getStaticSizes();
  ArrayRef<int64_t> firstStrides = firstSliceOp.getStaticStrides();
  const int64_t rank = firstType.getRank();
  const int64_t rowsPerValue = firstSizes[0];
  // Source-row distance between the starts of two consecutive slices that tile
  // contiguously. With stride s, a slice of r rows spans r * s source rows.
  const int64_t rowStep = rowsPerValue * firstStrides[0];

  for (size_t index = 1; index < values.size(); ++index) {
    auto sliceOp = values[index].getDefiningOp<tensor::ExtractSliceOp>();
    if (!sliceOp || sliceOp.getSource() != source || sliceOp.getResult().getType() != firstType
        || !isFullyStatic(sliceOp))
      return {};

    if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides)
      return {};

    ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
    if (sliceOffsets[0] != firstOffsets[0] + static_cast<int64_t>(index) * rowStep)
      return {};
    if (sliceOffsets.drop_front() != firstOffsets.drop_front())
      return {};
  }

  auto packedType = getPackedTensorType(firstType, static_cast<int64_t>(values.size()));

  // The merged slice is the whole source only if it starts at the origin with unit strides.
  bool coversWholeSource = packedType == sourceType;
  for (int64_t dim = 0; coversWholeSource && dim < rank; ++dim)
    coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1;
  if (coversWholeSource)
    return source;

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    offsets.push_back(builder.getIndexAttr(firstOffsets[dim]));
    sizes.push_back(
        builder.getIndexAttr(dim == 0 ? rowsPerValue * static_cast<int64_t>(values.size()) : firstSizes[dim]));
    strides.push_back(builder.getIndexAttr(firstStrides[dim]));
  }

  return tensor::ExtractSliceOp::create(builder, loc, packedType, source, offsets, sizes, strides).getResult();
}
|
|
|
|
/// Adds the tensor-packing rewrite pattern (PackSpatialConcatInputsPattern) to
/// `patterns`, constructed with the set's own MLIR context.
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
  MLIRContext* context = patterns.getContext();
  patterns.add<PackSpatialConcatInputsPattern>(context);
}
|
|
|
|
/// Erases, through `rewriter`, every `tensor.concat`, `tensor.extract_slice` and
/// `spatial::SpatExtractRowsOp` inside `funcOp` that has no remaining uses.
///
/// The op kinds are handled one after another in that order. Each kind is handled by
/// re-walking the function, so ops made dead by an earlier kind are seen by later ones.
/// Within a kind, candidates are checked in the reverse of the order the walk visited
/// them.
void eraseUnusedTensorPackingOps(func::FuncOp funcOp, IRRewriter& rewriter) {
  // `typeTag` is a default-constructed op handle used only to select the op type.
  auto eraseDeadOfType = [&funcOp, &rewriter](auto typeTag) {
    using OpT = decltype(typeTag);

    SmallVector<OpT> candidates;
    funcOp.walk([&candidates](OpT candidate) { candidates.push_back(candidate); });

    for (OpT candidate : llvm::reverse(candidates)) {
      if (!candidate->use_empty())
        continue;
      rewriter.eraseOp(candidate);
    }
  };

  eraseDeadOfType(tensor::ConcatOp {});
  eraseDeadOfType(tensor::ExtractSliceOp {});
  eraseDeadOfType(spatial::SpatExtractRowsOp {});
}
|
|
|
|
} // namespace onnx_mlir
|