use uniqued constant helpers everywhere materialize transposed constants directly
This commit is contained in:
@@ -40,14 +40,6 @@ static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr,
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
||||
return normalizeAxesImpl(std::optional<ArrayAttr>(axesAttr), rank);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> normalizeAxes(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
return normalizeAxesImpl(axesAttr, rank);
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||
for (int64_t axis : normalizedAxes)
|
||||
@@ -56,11 +48,7 @@ FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> ax
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) {
|
||||
return normalizeAxesChecked(std::optional<ArrayAttr>(axesAttr), rank);
|
||||
}
|
||||
|
||||
Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||
@@ -68,22 +56,22 @@ Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, Affin
|
||||
|
||||
Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) {
|
||||
if (multiplier == 0)
|
||||
return getOrCreateHostIndexConstant(rewriter, anchorOp, 0);
|
||||
return getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value});
|
||||
}
|
||||
|
||||
Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
if (divisor == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value});
|
||||
}
|
||||
|
||||
Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) {
|
||||
@@ -92,12 +80,12 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value});
|
||||
}
|
||||
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) {
|
||||
Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast<IntegerAttr>(attr).getInt());
|
||||
return cast<Value>(value);
|
||||
}
|
||||
|
||||
|
||||
@@ -19,18 +19,12 @@ mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||
|
||||
llvm::SmallVector<int64_t> normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank);
|
||||
|
||||
llvm::SmallVector<int64_t> normalizeAxes(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange operands);
|
||||
mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange operands);
|
||||
|
||||
mlir::Value
|
||||
multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier);
|
||||
@@ -40,6 +34,6 @@ mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location l
|
||||
mlir::Value
|
||||
floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor);
|
||||
|
||||
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value);
|
||||
mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
@@ -19,10 +20,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return getOrMaterializeIndexValue(rewriter, loc, result);
|
||||
}
|
||||
|
||||
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
APInt lhsConst;
|
||||
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
|
||||
@@ -43,11 +40,12 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt
|
||||
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
|
||||
|
||||
if (factorConst.isZero())
|
||||
return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
if (factorConst.isOne())
|
||||
return value;
|
||||
|
||||
auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult();
|
||||
auto factorValue =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue());
|
||||
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
||||
}
|
||||
|
||||
@@ -61,8 +59,6 @@ int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); }
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
@@ -226,49 +222,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
||||
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
||||
|
||||
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
||||
size_t numHSlices = hSlices.size();
|
||||
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
||||
Value hSlice = hSlices[hSliceId];
|
||||
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
||||
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
||||
size_t coreId = vSliceId / crossbarCountInCore;
|
||||
Value vSlice = vSlices[vSliceId];
|
||||
tiles[hSliceId][coreId].push_back(vSlice);
|
||||
}
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
Value materializeContiguousTensorSlice(Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
@@ -294,7 +247,7 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
|
||||
SmallVector<Value> zeroIndices(resultType.getRank());
|
||||
for (Value& zeroIndex : zeroIndices)
|
||||
zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
SmallVector<Value> resultIndices;
|
||||
resultIndices.reserve(resultType.getRank());
|
||||
@@ -304,7 +257,7 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
SmallVector<Value> sourceIndices;
|
||||
sourceIndices.reserve(resultType.getRank());
|
||||
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
|
||||
Value offsetValue = getIndexValue(offsets[idx], rewriter, loc);
|
||||
Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]);
|
||||
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
|
||||
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
|
||||
}
|
||||
@@ -337,8 +290,8 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
}
|
||||
|
||||
Value lower = zeroIndices[dim];
|
||||
Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult();
|
||||
Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult();
|
||||
Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim));
|
||||
Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
resultIndices.push_back(loop.getInductionVar());
|
||||
@@ -352,17 +305,6 @@ Value materializeContiguousTensorSlice(Value source,
|
||||
return buildLoopNest(buildLoopNest, 0, init);
|
||||
}
|
||||
|
||||
Value extractStaticSlice(PatternRewriter& rewriter,
|
||||
Location loc,
|
||||
Value source,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<OpFoldResult> offsets) {
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()),
|
||||
getUnitStrides(rewriter, resultType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
|
||||
@@ -18,41 +18,6 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(1);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageN(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
@@ -89,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[1] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||
assert(isVectorShape(shape));
|
||||
return shape[0] != 1 ? shape[0] : shape[1];
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
@@ -117,8 +71,6 @@ bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::RankedTensorType type);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
@@ -156,20 +108,6 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
/// Tiles a matrix first across output columns and then across input rows so it
|
||||
/// can be assigned to crossbars grouped by core.
|
||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||
tileMatrix(mlir::Value& matrixToTile,
|
||||
int64_t hSliceSize,
|
||||
int64_t vSliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
@@ -177,12 +115,6 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::RankedTensorType resultType,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
|
||||
@@ -8,10 +8,10 @@
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -38,13 +38,6 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
|
||||
@@ -61,9 +61,9 @@ static Value createPaddedRows(Value tensorValue,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
@@ -106,7 +106,7 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
}
|
||||
|
||||
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
||||
}
|
||||
|
||||
static Value createConvWeightMatrix(Value w,
|
||||
@@ -158,7 +158,7 @@ static Value buildPackedBias(bool hasBias,
|
||||
|
||||
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
|
||||
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType);
|
||||
}
|
||||
|
||||
static Value createIm2colRowComputes(Value x,
|
||||
@@ -214,8 +214,8 @@ static Value createIm2colRowComputes(Value x,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType);
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
@@ -223,13 +223,14 @@ static Value createIm2colRowComputes(Value x,
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches);
|
||||
auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch);
|
||||
auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth);
|
||||
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
|
||||
@@ -83,7 +83,7 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
|
||||
}
|
||||
|
||||
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
@@ -121,7 +121,7 @@ static FailureOr<Value> materializeReciprocalTensor(Value value,
|
||||
}
|
||||
|
||||
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType);
|
||||
}
|
||||
|
||||
template <typename OnnxOp, typename SpatialOp>
|
||||
|
||||
@@ -50,38 +50,17 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
|
||||
return failure();
|
||||
|
||||
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||
}
|
||||
|
||||
static Value transposeForSpatial(Value value,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value
|
||||
multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier);
|
||||
}
|
||||
|
||||
static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor);
|
||||
}
|
||||
|
||||
static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return modIndexByConstant(lane, numOutRows, rewriter, loc);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
|
||||
}
|
||||
|
||||
static Value createGemmBatchKOffset(
|
||||
Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (numKSlices == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -92,11 +71,11 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (numOutHSlices == 1)
|
||||
return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane});
|
||||
}
|
||||
|
||||
@@ -115,9 +94,9 @@ createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatte
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
@@ -149,7 +128,7 @@ static FailureOr<Value> materializePaddedConstantMatrix(Value value,
|
||||
resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];
|
||||
|
||||
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
|
||||
@@ -215,7 +194,7 @@ static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
|
||||
}
|
||||
|
||||
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value> prepareBias(Value c,
|
||||
@@ -274,7 +253,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc);
|
||||
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
|
||||
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
|
||||
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
|
||||
@@ -312,12 +291,7 @@ static Value createDynamicGemmBatchRow(
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value createDynamicGemmBatchColumn(
|
||||
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
return modIndexByConstant(lane, numOutCols, rewriter, loc);
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||
}
|
||||
|
||||
static Value
|
||||
@@ -385,7 +359,7 @@ static Value createScalarTensorConstant(RankedTensorType scalarType,
|
||||
auto elementType = scalarType.getElementType();
|
||||
auto scalarAttr = rewriter.getFloatAttr(elementType, value);
|
||||
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
|
||||
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType);
|
||||
}
|
||||
|
||||
static Value createBroadcastedBiasScalar(Value bias,
|
||||
@@ -435,7 +409,7 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc);
|
||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
@@ -475,16 +449,16 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||
Value biasArg = bias ? blockArgs[1] : Value();
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value lane = loop.getInductionVar();
|
||||
Value outputAcc = loop.getRegionIterArgs().front();
|
||||
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
|
||||
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc);
|
||||
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols);
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
@@ -522,7 +496,7 @@ static Value createPartialGroupOffset(Value hSlice,
|
||||
Location loc) {
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
return createAffineApplyOrConstant(
|
||||
return createAffineApplyOrFoldedConstant(
|
||||
rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice});
|
||||
}
|
||||
|
||||
@@ -604,7 +578,9 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
|
||||
Value reduced =
|
||||
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
|
||||
Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc);
|
||||
Value hOffset =
|
||||
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
|
||||
crossbarSize.getValue());
|
||||
if (biasArg) {
|
||||
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
|
||||
Value biasSlice =
|
||||
@@ -620,13 +596,14 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
|
||||
|
||||
Value paddedOutput = outputInit;
|
||||
if (numOutHSlices == 1) {
|
||||
Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
paddedOutput = buildOutputSlice(outputInit, hSlice);
|
||||
}
|
||||
else {
|
||||
Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
|
||||
@@ -763,7 +740,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
if (gemmOpAdaptor.getTransB()) {
|
||||
auto bShape = bType.getShape();
|
||||
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
|
||||
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
|
||||
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
|
||||
bType = cast<RankedTensorType>(b.getType());
|
||||
}
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ static Value computeLaneIndex(Value lane,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (dimSize == 1)
|
||||
return arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
@@ -85,7 +85,7 @@ static Value computeLaneIndex(Value lane,
|
||||
expr = expr.floorDiv(stride);
|
||||
if (dimSize != 1)
|
||||
expr = expr % dimSize;
|
||||
return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane});
|
||||
return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane});
|
||||
}
|
||||
|
||||
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
||||
@@ -236,7 +236,7 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
Location loc) {
|
||||
if (resultType.getRank() == 0) {
|
||||
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
|
||||
arith::ConstantIndexOp::create(rewriter, loc, 0));
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
|
||||
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
|
||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||
}
|
||||
@@ -268,7 +268,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
||||
auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
|
||||
|
||||
@@ -31,17 +31,18 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
|
||||
static Value
|
||||
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
|
||||
|
||||
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
|
||||
}
|
||||
|
||||
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
|
||||
}
|
||||
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
@@ -148,7 +149,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
|
||||
}
|
||||
|
||||
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -265,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
|
||||
Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
|
||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||
@@ -296,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
|
||||
@@ -52,9 +52,10 @@ static Value buildLoopSoftmaxNest(Value input,
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
@@ -17,9 +17,10 @@ namespace {
|
||||
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim);
|
||||
Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim);
|
||||
Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
@@ -37,12 +38,13 @@ static Value buildNearestResizeLoop(Value input,
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0));
|
||||
Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1));
|
||||
Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2));
|
||||
Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3));
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -30,6 +32,54 @@ static Value createTransposeInit(Value input,
|
||||
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeTransposedConstant(Value input,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto denseAttr = getHostConstDenseElementsAttr(input);
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape()
|
||||
|| inputType.getRank() != resultType.getRank()
|
||||
|| static_cast<int64_t>(permutation.size()) != inputType.getRank()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return getOrCreateConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>()),
|
||||
resultType);
|
||||
|
||||
SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(inputValues.size());
|
||||
SmallVector<int64_t> inputStrides = computeRowMajorStrides(inputType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(inputValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
|
||||
inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim];
|
||||
remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim];
|
||||
}
|
||||
|
||||
int64_t resultLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
|
||||
resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim];
|
||||
|
||||
resultValues[resultLinearIndex] = value;
|
||||
}
|
||||
|
||||
return getOrCreateConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
DenseElementsAttr::get(resultType, resultValues),
|
||||
resultType);
|
||||
}
|
||||
|
||||
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -44,6 +94,14 @@ struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
|
||||
if (failed(permutation))
|
||||
return failure();
|
||||
if (isCompileTimeComputable(adaptor.getData())) {
|
||||
auto constantTranspose =
|
||||
materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
if (succeeded(constantTranspose)) {
|
||||
rewriter.replaceOp(transposeOp, *constantTranspose);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
Value transposed =
|
||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation)
|
||||
|
||||
Reference in New Issue
Block a user