This commit is contained in:
@@ -10,6 +10,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Post.cpp
|
||||
Patterns/GeneratedConversion.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/ConvGeometry.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
@@ -30,7 +31,7 @@ add_pim_library(OMONNXToSpatial
|
||||
LowerSpatialPlansPass.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
Common/MatrixProductLowering.cpp
|
||||
Common/ShapeTilingUtils.cpp
|
||||
Common/WeightMaterialization.cpp
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "MatrixProductLowering.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
|
||||
int64_t normalizedAxis = normalizeAxis(axis, rank);
|
||||
if (normalizedAxis < 0 || normalizedAxis >= rank)
|
||||
return failure();
|
||||
return normalizedAxis;
|
||||
}
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||
|
||||
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes;
|
||||
if (!axesAttr) {
|
||||
normalizedAxes.reserve(rank);
|
||||
for (int64_t axis = 0; axis < rank; ++axis)
|
||||
normalizedAxes.push_back(axis);
|
||||
}
|
||||
else {
|
||||
normalizedAxes.reserve(axesAttr->size());
|
||||
for (Attribute attr : *axesAttr)
|
||||
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
|
||||
llvm::sort(normalizedAxes);
|
||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||
}
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||
for (int64_t axis : normalizedAxes)
|
||||
if (axis < 0 || axis >= rank)
|
||||
return failure();
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,20 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank);
|
||||
|
||||
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,48 @@
|
||||
#include "MatrixProductLowering.hpp"
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
Value createPaddedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Value createZeroPaddedTensor(mlir::Value value,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value createPaddedInputCompute(mlir::Value input,
|
||||
mlir::RankedTensorType paddedInputType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,9 +3,6 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
@@ -15,73 +12,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, PatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
|
||||
@@ -1,89 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PlanLowering.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns/Math/ConvGeometry.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -42,59 +43,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
struct ConvLoweringState {
|
||||
Value x;
|
||||
Value w;
|
||||
Value b;
|
||||
RankedTensorType xType;
|
||||
RankedTensorType wType;
|
||||
RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct ConvLoweringDecision {
|
||||
PimConvLoweringType strategy;
|
||||
std::string reason;
|
||||
@@ -108,19 +56,6 @@ struct PreparedConvInput {
|
||||
RankedTensorType type;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
struct ConvStrategyEstimate {
|
||||
uint64_t estimatedMvmCount = 0;
|
||||
uint64_t estimatedReductionVAddCount = 0;
|
||||
@@ -291,9 +226,6 @@ static FailureOr<Value> createRowStripPackedRows(Value rows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc);
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
static FailureOr<ConvLoweringState> analyzeConvLoweringState(ONNXConvOp convOp, Value x, Value w, Value b);
|
||||
|
||||
static StringRef stringifyDistributedConvBarrierKind(DistributedConvBarrierKind kind) {
|
||||
@@ -391,34 +323,6 @@ static ConvStrategyEstimate estimateConvStrategy(const ConvGeometry& geo,
|
||||
return estimate;
|
||||
}
|
||||
|
||||
static ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
|
||||
ConvGeometry geo {
|
||||
state.batchSize,
|
||||
state.numChannelsIn,
|
||||
state.xHeight,
|
||||
state.xWidth,
|
||||
state.numChannelsOut,
|
||||
state.wHeight,
|
||||
state.wWidth,
|
||||
state.outHeight,
|
||||
state.outWidth,
|
||||
state.group,
|
||||
state.numChannelsInPerGroup,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.batchSize * state.outHeight * state.outWidth,
|
||||
static_cast<int64_t>(crossbarSize.getValue()),
|
||||
1,
|
||||
0,
|
||||
state.hasBias,
|
||||
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
|
||||
};
|
||||
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
|
||||
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
|
||||
return geo;
|
||||
}
|
||||
|
||||
static std::string formatShape(ArrayRef<int64_t> dims) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
@@ -563,36 +467,10 @@ classifyDistributedBinaryConsumer(Operation* user,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows,
|
||||
int64_t inputHeight,
|
||||
int64_t kernelH,
|
||||
int64_t strideH,
|
||||
int64_t dilationH,
|
||||
int64_t padTop) {
|
||||
const int64_t rawBegin = outputRows.begin * strideH - padTop;
|
||||
const int64_t rawEnd = (outputRows.end - 1) * strideH - padTop + dilationH * (kernelH - 1) + 1;
|
||||
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(inputHeight, rawEnd)};
|
||||
}
|
||||
|
||||
static bool covers(RowInterval acquired, RowInterval needed) {
|
||||
return acquired.begin <= needed.begin && acquired.end >= needed.end;
|
||||
}
|
||||
|
||||
static ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
RowInterval neededInputRows = computeConvInputRowsForOutputRows(
|
||||
outputRows, state.xHeight, state.wHeight, state.strideHeight, state.dilationHeight, state.padHeightBegin);
|
||||
ConvRowDemand demand;
|
||||
demand.outputRows = outputRows;
|
||||
demand.neededInputRows = neededInputRows;
|
||||
demand.acquiredInputRows = neededInputRows;
|
||||
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
|
||||
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
|
||||
return demand;
|
||||
}
|
||||
|
||||
static bool canConsumeRowStripHwcInput(const ConvLoweringState& state, StringRef& failureReason) {
|
||||
if (state.batchSize != 1) {
|
||||
failureReason = "unsupported_batch";
|
||||
@@ -1250,19 +1128,6 @@ static void reportConvLoweringDecision(ONNXConvOp convOp,
|
||||
rewriteConvLoweringReport(reportEntries);
|
||||
}
|
||||
|
||||
static uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -1278,11 +1143,6 @@ static Value expandBiasIfNeeded(Value bias, PatternRewriter& rewriter, Location
|
||||
});
|
||||
}
|
||||
|
||||
static bool
|
||||
isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
|
||||
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
|
||||
}
|
||||
|
||||
static int64_t findLargestDivisorAtMost(int64_t value, int64_t limit) {
|
||||
assert(value > 0 && "expected positive value");
|
||||
limit = std::min(value, limit);
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
#include "ConvGeometry.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup) {
|
||||
return group == numChannelsIn && numChannelsInPerGroup == 1 && numChannelsOut % group == 0;
|
||||
}
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state) {
|
||||
ConvGeometry geo {
|
||||
state.batchSize,
|
||||
state.numChannelsIn,
|
||||
state.xHeight,
|
||||
state.xWidth,
|
||||
state.numChannelsOut,
|
||||
state.wHeight,
|
||||
state.wWidth,
|
||||
state.outHeight,
|
||||
state.outWidth,
|
||||
state.group,
|
||||
state.numChannelsInPerGroup,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.numChannelsInPerGroup * state.wHeight * state.wWidth,
|
||||
state.numChannelsOutPerGroup,
|
||||
state.batchSize * state.outHeight * state.outWidth,
|
||||
static_cast<int64_t>(crossbarSize.getValue()),
|
||||
1,
|
||||
0,
|
||||
state.hasBias,
|
||||
isDepthwiseConv(state.group, state.numChannelsIn, state.numChannelsOut, state.numChannelsInPerGroup),
|
||||
};
|
||||
geo.pack = std::max<int64_t>(1, geo.xbarSize / std::max<int64_t>(geo.k, geo.c));
|
||||
geo.im2colElements = static_cast<uint64_t>(std::max<int64_t>(0, geo.p)) * static_cast<uint64_t>(std::max<int64_t>(0, geo.k));
|
||||
return geo;
|
||||
}
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor) {
|
||||
const uint64_t patchElements = static_cast<uint64_t>(std::max<int64_t>(1, geo.k));
|
||||
uint64_t chunkPositions = std::max<uint64_t>(1, pimConvIm2colMaxElements / patchElements);
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, static_cast<uint64_t>(std::max<int64_t>(1, geo.p)));
|
||||
chunkPositions = std::min<uint64_t>(chunkPositions, std::max<uint64_t>(1, pimConvStreamChunkPositions));
|
||||
|
||||
if (packFactor > 1 && chunkPositions > static_cast<uint64_t>(packFactor)) {
|
||||
chunkPositions -= chunkPositions % static_cast<uint64_t>(packFactor);
|
||||
chunkPositions = std::max<uint64_t>(chunkPositions, static_cast<uint64_t>(packFactor));
|
||||
}
|
||||
return std::max<uint64_t>(1, chunkPositions);
|
||||
}
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
return {std::max<int64_t>(0, rawBegin), std::min<int64_t>(state.xHeight, rawEnd)};
|
||||
}
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state) {
|
||||
ConvRowDemand demand;
|
||||
demand.outputRows = outputRows;
|
||||
demand.neededInputRows = computeConvInputRowsForOutputRows(outputRows, state);
|
||||
demand.acquiredInputRows = demand.neededInputRows;
|
||||
|
||||
const int64_t rawBegin = outputRows.begin * state.strideHeight - state.padHeightBegin;
|
||||
const int64_t rawEnd =
|
||||
(outputRows.end - 1) * state.strideHeight - state.padHeightBegin + state.dilationHeight * (state.wHeight - 1) + 1;
|
||||
demand.topHaloRows = std::max<int64_t>(0, -rawBegin);
|
||||
demand.bottomHaloRows = std::max<int64_t>(0, rawEnd - state.xHeight);
|
||||
demand.acquiredInputRows = demand.neededInputRows;
|
||||
return demand;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,86 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ConvLoweringState {
|
||||
mlir::Value x;
|
||||
mlir::Value w;
|
||||
mlir::Value b;
|
||||
mlir::RankedTensorType xType;
|
||||
mlir::RankedTensorType wType;
|
||||
mlir::RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -87,28 +87,6 @@ static Value createGemmBatchHOffset(Value lane,
|
||||
rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
static Value
|
||||
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedConstantMatrix(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
@@ -232,22 +210,6 @@ static Value extractATile(
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
|
||||
}
|
||||
|
||||
static Value createPaddedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<spatial::SpatComputeBatch> createVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
|
||||
@@ -255,42 +255,6 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
return createONNXTranspose(resultType, {0, 2, 1});
|
||||
}
|
||||
|
||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value createPaddedBatchedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBatchedWeight(Value value,
|
||||
ArrayRef<int64_t> sourceBatchShape,
|
||||
ArrayRef<int64_t> targetBatchShape,
|
||||
@@ -1055,7 +1019,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
auto paddedRhs =
|
||||
materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter);
|
||||
if (succeeded(paddedRhs)) {
|
||||
Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
Value paddedLhs = createPaddedInputCompute(plan.lhs, paddedLhsType, rewriter, loc);
|
||||
const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices;
|
||||
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
shapeInfo->outType.getElementType());
|
||||
|
||||
Reference in New Issue
Block a user