less affine code and better affine helpers
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -5,5 +5,6 @@
|
||||
#include "MatrixProductLowering.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -77,65 +77,4 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewri
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractStaticSliceOrIdentity(RewriterBase& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
size_t rank = static_cast<size_t>(sourceType.getRank());
|
||||
|
||||
bool isIdentitySlice =
|
||||
sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank
|
||||
&& strides.size() == rank;
|
||||
if (isIdentitySlice) {
|
||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||
for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) {
|
||||
std::optional<int64_t> staticOffset = mlir::getConstantIntValue(offset);
|
||||
std::optional<int64_t> staticSize = mlir::getConstantIntValue(size);
|
||||
std::optional<int64_t> staticStride = mlir::getConstantIntValue(stride);
|
||||
if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) {
|
||||
isIdentitySlice = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isIdentitySlice)
|
||||
return source;
|
||||
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -28,21 +28,4 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> strides);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user