merge remote changes
This commit is contained in:
@@ -23,7 +23,10 @@ static Value extractSliceAt(
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
SmallVector<int64_t> resultShape(inputType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
@@ -49,12 +52,7 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
int64_t sliceSize = resultType.getShape()[axis];
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
|
||||
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
|
||||
});
|
||||
outputs.push_back(computeOp.getResult(0));
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
offset += sliceSize;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user