extend operation support for conv and gemm
add more tests in validation
This commit is contained in:
@@ -5,9 +5,11 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
|
|||||||
add_onnx_mlir_library(OMONNXToSpatial
|
add_onnx_mlir_library(OMONNXToSpatial
|
||||||
Math/Gemm.cpp
|
Math/Gemm.cpp
|
||||||
Math/Conv.cpp
|
Math/Conv.cpp
|
||||||
|
Math/MatMul.cpp
|
||||||
NN/Pooling.cpp
|
NN/Pooling.cpp
|
||||||
NN/ReduceMean.cpp
|
NN/ReduceMean.cpp
|
||||||
Tensor/ONNXConcatToTensorConcat.cpp
|
Tensor/ONNXConcatToTensorConcat.cpp
|
||||||
|
Tensor/ONNXReshapeToTensorReshape.cpp
|
||||||
Tensor/RemoveUnusedHelperOps.cpp
|
Tensor/RemoveUnusedHelperOps.cpp
|
||||||
Utils/SpatialReducer.cpp
|
Utils/SpatialReducer.cpp
|
||||||
Utils/WeightSubdivider.cpp
|
Utils/WeightSubdivider.cpp
|
||||||
|
|||||||
@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
});
|
});
|
||||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
|
||||||
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
Value gemmC;
|
Value gemmC;
|
||||||
if (hasB) {
|
if (hasB)
|
||||||
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
|
gemmC = b;
|
||||||
gemmC = tensor::ExpandShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
biasType,
|
|
||||||
b,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0, 1}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
else
|
else
|
||||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,38 @@ namespace {
|
|||||||
|
|
||||||
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
|
||||||
|
|
||||||
|
static FailureOr<Value> materializeScaledConstantTensor(Value value,
|
||||||
|
float factor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (factor == 1.0f)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
||||||
|
if (!constantOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
|
||||||
|
if (!denseAttr)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<APFloat> scaledValues;
|
||||||
|
scaledValues.reserve(denseAttr.getNumElements());
|
||||||
|
APFloat scale(factor);
|
||||||
|
bool hadFailure = false;
|
||||||
|
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
|
||||||
|
APFloat scaledValue(originalValue);
|
||||||
|
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
|
||||||
|
hadFailure = true;
|
||||||
|
scaledValues.push_back(std::move(scaledValue));
|
||||||
|
}
|
||||||
|
if (hadFailure)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
if (numOutRows <= 1)
|
if (numOutRows <= 1)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledB))
|
||||||
|
return failure();
|
||||||
|
b = *scaledB;
|
||||||
|
|
||||||
RankedTensorType cType = nullptr;
|
RankedTensorType cType = nullptr;
|
||||||
bool cHasNumOutRows = false;
|
bool cHasNumOutRows = false;
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
|
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||||
|
if (failed(scaledC))
|
||||||
|
return failure();
|
||||||
|
c = *scaledC;
|
||||||
cType = cast<RankedTensorType>(c.getType());
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
|
if (cType.getRank() == 1) {
|
||||||
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
|
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||||
|
cType = expandedType;
|
||||||
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||||
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
|
||||||
}
|
}
|
||||||
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
aSlice,
|
aSlice,
|
||||||
b,
|
b,
|
||||||
cSlice,
|
cSlice,
|
||||||
gemmOp.getAlphaAttr(),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
gemmOp.getBetaAttr(),
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
gemmOp.getTransAAttr(),
|
gemmOp.getTransAAttr(),
|
||||||
gemmOp.getTransBAttr());
|
gemmOp.getTransBAttr());
|
||||||
gemvOps.push_back(gemvOp.getY());
|
gemvOps.push_back(gemvOp.getY());
|
||||||
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
cType = cast<RankedTensorType>(c.getType());
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
|
||||||
|
if (cType.getRank() == 1) {
|
||||||
|
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
|
||||||
|
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
|
||||||
|
cType = expandedType;
|
||||||
|
}
|
||||||
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
auto bShape = bType.getShape();
|
auto bShape = bType.getShape();
|
||||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||||
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (alpha != 1.0f) {
|
if (alpha != 1.0f) {
|
||||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
|
||||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
if (failed(scaledB))
|
||||||
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
|
return failure();
|
||||||
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
|
b = *scaledB;
|
||||||
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
|
alpha = 1.0f;
|
||||||
}
|
}
|
||||||
if (hasC && beta != 1.0f) {
|
if (hasC && beta != 1.0f) {
|
||||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
|
||||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
if (failed(scaledC))
|
||||||
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
|
return failure();
|
||||||
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
|
c = *scaledC;
|
||||||
|
cType = cast<RankedTensorType>(c.getType());
|
||||||
|
beta = 1.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||||
|
|||||||
108
src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp
Normal file
108
src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||||
|
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||||
|
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||||
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|
|| !outType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t batch = rhsType.getDimSize(0);
|
||||||
|
const int64_t k = rhsType.getDimSize(1);
|
||||||
|
const int64_t n = rhsType.getDimSize(2);
|
||||||
|
const int64_t m = lhsType.getDimSize(0);
|
||||||
|
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|
||||||
|
|| outType.getDimSize(2) != n)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Location loc = matmulOp.getLoc();
|
||||||
|
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
|
||||||
|
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
|
||||||
|
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
|
||||||
|
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
|
||||||
|
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
|
||||||
|
|
||||||
|
Value lhsTransposed =
|
||||||
|
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
|
||||||
|
SmallVector<Value> gemmRows;
|
||||||
|
gemmRows.reserve(batch * n);
|
||||||
|
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||||
|
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {
|
||||||
|
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> strides = {
|
||||||
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value rhsSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
|
||||||
|
Value rhsRow = tensor::CollapseShapeOp::create(
|
||||||
|
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
|
||||||
|
|
||||||
|
auto gemmOp = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmRowType,
|
||||||
|
rhsRow,
|
||||||
|
lhsTransposed,
|
||||||
|
none,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false));
|
||||||
|
gemmRows.push_back(gemmOp.getY());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto concatComputeOp =
|
||||||
|
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
|
||||||
|
|
||||||
|
auto* concatBlock = new Block();
|
||||||
|
for (Value gemmRow : gemmRows)
|
||||||
|
concatBlock->addArgument(gemmRow.getType(), loc);
|
||||||
|
concatComputeOp.getBody().push_back(concatBlock);
|
||||||
|
rewriter.setInsertionPointToStart(concatBlock);
|
||||||
|
|
||||||
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(concatComputeOp);
|
||||||
|
Value gemmOut = concatComputeOp.getResult(0);
|
||||||
|
Value gemmExpanded = tensor::ExpandShapeOp::create(
|
||||||
|
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
|
||||||
|
Value result = ONNXTransposeOp::create(
|
||||||
|
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||||
|
|
||||||
|
rewriter.replaceOp(matmulOp, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<MatMulRank3ToGemm>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
|
|||||||
|
|
||||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||||
|
|
||||||
|
def IsRank2Result: Constraint<
|
||||||
|
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||||
|
"Result is rank 2">;
|
||||||
|
|
||||||
def matMulAddToGemmPattern : Pat<
|
def matMulAddToGemmPattern : Pat<
|
||||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||||
(ONNXGemmOp $A, $B, $C,
|
(ONNXGemmOp $A, $B, $C,
|
||||||
@@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat<
|
|||||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||||
)
|
),
|
||||||
|
[(IsRank2Result $matmulres)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
def matMulToGemmPattern : Pat<
|
def matMulToGemmPattern : Pat<
|
||||||
(ONNXMatMulOp:$matmulres $A, $B),
|
(ONNXMatMulOp:$matmulres $A, $B),
|
||||||
(
|
(
|
||||||
ONNXGemmOp $A, $B,
|
ONNXGemmOp $A, $B,
|
||||||
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||||
)
|
),
|
||||||
|
[(IsRank2Result $matmulres)]
|
||||||
>;
|
>;
|
||||||
|
|
||||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
@@ -56,6 +55,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
|
||||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||||
|
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||||
@@ -74,7 +74,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
|
||||||
target.addIllegalOp<ONNXMatMulOp>();
|
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
|
||||||
|
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
|
||||||
|
});
|
||||||
target.addIllegalOp<ONNXGemmOp>();
|
target.addIllegalOp<ONNXGemmOp>();
|
||||||
target.addIllegalOp<ONNXConvOp>();
|
target.addIllegalOp<ONNXConvOp>();
|
||||||
target.addIllegalOp<ONNXLRNOp>();
|
target.addIllegalOp<ONNXLRNOp>();
|
||||||
@@ -83,6 +85,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
target.addIllegalOp<ONNXConcatOp>();
|
target.addIllegalOp<ONNXConcatOp>();
|
||||||
target.addIllegalOp<ONNXSoftmaxOp>();
|
target.addIllegalOp<ONNXSoftmaxOp>();
|
||||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||||
|
target.addIllegalOp<ONNXReshapeOp>();
|
||||||
|
|
||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
patterns.add<removeLRNPattern>(ctx);
|
patterns.add<removeLRNPattern>(ctx);
|
||||||
@@ -90,6 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
populateConvOpPatterns(patterns, ctx);
|
populateConvOpPatterns(patterns, ctx);
|
||||||
populatePoolingTilingPattern(patterns, ctx);
|
populatePoolingTilingPattern(patterns, ctx);
|
||||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||||
|
populateReshapeConversionPattern(patterns, ctx);
|
||||||
|
|
||||||
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
populateONNXConcatToTensorConcatPattern(patterns, ctx);
|
||||||
populateReduceMeanConversionPattern(patterns, ctx);
|
populateReduceMeanConversionPattern(patterns, ctx);
|
||||||
|
|||||||
@@ -7,12 +7,16 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||||
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
SmallVector<ReassociationIndices>& reassociation) {
|
||||||
|
reassociation.clear();
|
||||||
|
|
||||||
|
size_t sourceIdx = 0;
|
||||||
|
size_t resultIdx = 0;
|
||||||
|
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||||
|
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||||
|
int64_t resultProduct = resultShape[resultIdx];
|
||||||
|
|
||||||
|
ReassociationIndices group;
|
||||||
|
group.push_back(sourceIdx);
|
||||||
|
while (sourceProduct != resultProduct) {
|
||||||
|
if (sourceProduct > resultProduct)
|
||||||
|
return false;
|
||||||
|
sourceIdx++;
|
||||||
|
if (sourceIdx >= sourceShape.size())
|
||||||
|
return false;
|
||||||
|
group.push_back(sourceIdx);
|
||||||
|
sourceProduct *= sourceShape[sourceIdx];
|
||||||
|
}
|
||||||
|
|
||||||
|
reassociation.push_back(group);
|
||||||
|
sourceIdx++;
|
||||||
|
resultIdx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
SmallVector<ReassociationIndices>& reassociation) {
|
||||||
|
reassociation.clear();
|
||||||
|
|
||||||
|
size_t sourceIdx = 0;
|
||||||
|
size_t resultIdx = 0;
|
||||||
|
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
|
||||||
|
int64_t sourceProduct = sourceShape[sourceIdx];
|
||||||
|
int64_t resultProduct = resultShape[resultIdx];
|
||||||
|
|
||||||
|
ReassociationIndices group;
|
||||||
|
group.push_back(resultIdx);
|
||||||
|
while (resultProduct != sourceProduct) {
|
||||||
|
if (resultProduct > sourceProduct)
|
||||||
|
return false;
|
||||||
|
resultIdx++;
|
||||||
|
if (resultIdx >= resultShape.size())
|
||||||
|
return false;
|
||||||
|
group.push_back(resultIdx);
|
||||||
|
resultProduct *= resultShape[resultIdx];
|
||||||
|
}
|
||||||
|
|
||||||
|
reassociation.push_back(group);
|
||||||
|
sourceIdx++;
|
||||||
|
resultIdx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
|
||||||
|
ONNXReshapeOpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (sourceType == resultType) {
|
||||||
|
rewriter.replaceOp(reshapeOp, adaptor.getData());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation;
|
||||||
|
if (sourceType.getRank() > resultType.getRank()
|
||||||
|
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sourceType.getRank() < resultType.getRank()
|
||||||
|
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.insert<ONNXReshapeToTensorReshape>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
@@ -79,8 +80,31 @@ private:
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static bool isChannelUseChainOp(Operation* op) {
|
static bool isChannelUseChainOp(Operation* op) {
|
||||||
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
|
return isa<tensor::ExtractSliceOp,
|
||||||
op);
|
tensor::CollapseShapeOp,
|
||||||
|
tensor::ExpandShapeOp,
|
||||||
|
tensor::CastOp,
|
||||||
|
tosa::ReshapeOp,
|
||||||
|
pim::PimTransposeOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
if (mapping.lookupOrNull(operand))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
|
if (!definingOp)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static size_t countComputeLeafUsers(Value value) {
|
static size_t countComputeLeafUsers(Value value) {
|
||||||
@@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
OpOperand& resultUse = *resultUses.begin();
|
OpOperand& resultUse = *resultUses.begin();
|
||||||
Operation* resultUser = resultUse.getOwner();
|
Operation* resultUser = resultUse.getOwner();
|
||||||
|
|
||||||
|
if (isChannelUseChainOp(resultUser)) {
|
||||||
|
SmallVector<Operation*> returnChain;
|
||||||
|
Value chainedValue = result;
|
||||||
|
Operation* chainUser = resultUser;
|
||||||
|
|
||||||
|
while (isChannelUseChainOp(chainUser)) {
|
||||||
|
returnChain.push_back(chainUser);
|
||||||
|
auto chainUses = chainUser->getResult(0).getUses();
|
||||||
|
if (rangeLength(chainUses) != 1)
|
||||||
|
break;
|
||||||
|
chainedValue = chainUser->getResult(0);
|
||||||
|
chainUser = chainUses.begin()->getOwner();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<func::ReturnOp>(chainUser)) {
|
||||||
|
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
|
IRMapping mapping;
|
||||||
|
mapping.map(result, yieldValue);
|
||||||
|
|
||||||
|
Value storedValue = yieldValue;
|
||||||
|
for (Operation* op : returnChain) {
|
||||||
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||||
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
|
mapping.map(originalResult, newResult);
|
||||||
|
storedValue = clonedOp->getResult(0);
|
||||||
|
rewriter.setInsertionPointAfter(clonedOp);
|
||||||
|
markOpToRemove(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||||
|
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
|
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||||
|
if (auto storedOp = storedValue.getDefiningOp())
|
||||||
|
rewriter.setInsertionPointAfter(storedOp);
|
||||||
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
outputTensor.getType(),
|
||||||
|
outputTensor,
|
||||||
|
storedValue,
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(0),
|
||||||
|
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (isa<func::ReturnOp>(resultUser)) {
|
if (isa<func::ReturnOp>(resultUser)) {
|
||||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
@@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
IRMapping mapping;
|
IRMapping mapping;
|
||||||
mapping.map(channelSourceOp, receivedValue);
|
mapping.map(channelSourceOp, receivedValue);
|
||||||
for (Operation* op : llvm::reverse(clonedChain)) {
|
for (Operation* op : llvm::reverse(clonedChain)) {
|
||||||
|
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||||
mapping.map(originalResult, newResult);
|
mapping.map(originalResult, newResult);
|
||||||
|
|||||||
@@ -30,6 +30,24 @@ static Value stripMemRefCasts(Value value) {
|
|||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value stripMemRefViewOps(Value value) {
|
||||||
|
while (true) {
|
||||||
|
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||||
|
value = castOp.getSource();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||||
|
value = collapseOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||||
|
value = expandOp.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||||
Location loc,
|
Location loc,
|
||||||
MemRefType globalType,
|
MemRefType globalType,
|
||||||
@@ -204,6 +222,7 @@ struct StaticSubviewInfo {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||||
|
value = stripMemRefViewOps(value);
|
||||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||||
if (!subviewOp)
|
if (!subviewOp)
|
||||||
return failure();
|
return failure();
|
||||||
@@ -321,6 +340,77 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
|
||||||
|
auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
|
||||||
|
const bool splitSrc = succeeded(srcSubview)
|
||||||
|
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
|
||||||
|
const bool splitDst = succeeded(dstSubview)
|
||||||
|
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
|
||||||
|
if (!splitSrc && !splitDst)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
|
||||||
|
auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
|
||||||
|
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (sourceType.getElementType() != dstType.getElementType())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
|
||||||
|
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
|
||||||
|
if (elementByteWidth <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
|
||||||
|
if (copyOp.getSize() != totalBytes)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
|
||||||
|
if (sliceBytes <= 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
|
||||||
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(copyOp);
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||||
|
SmallVector<int64_t> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
|
||||||
|
const int64_t srcByteOffset = copyOp.getHostSrcOffset()
|
||||||
|
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
|
||||||
|
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
|
||||||
|
: linearIndex * sliceBytes);
|
||||||
|
pim::PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter,
|
||||||
|
copyOp.getLoc(),
|
||||||
|
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
|
||||||
|
splitDst ? dstSubview->source : copyOp.getDeviceDst(),
|
||||||
|
splitSrc ? srcSubview->source : copyOp.getHostSrc(),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
|
||||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
if (!allocType || !allocType.hasStaticShape())
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
@@ -578,6 +668,170 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
|
||||||
|
// Only match top-level memcp (not inside pim.core)
|
||||||
|
if (copyOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// dst must be an alloc with static shape
|
||||||
|
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
|
||||||
|
if (!allocOp)
|
||||||
|
return failure();
|
||||||
|
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||||
|
if (!allocType || !allocType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// The copy must cover the full destination (offsets both zero)
|
||||||
|
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Resolve the source through an optional subview to a get_global
|
||||||
|
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
|
||||||
|
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
|
||||||
|
|
||||||
|
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
|
||||||
|
if (failed(denseAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Build the folded dense attribute
|
||||||
|
DenseElementsAttr foldedAttr;
|
||||||
|
if (succeeded(srcSubview)) {
|
||||||
|
// Extract the sub-tensor from the source constant
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
|
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> resultValues(numResultElements);
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||||
|
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(resultIndices.size());
|
||||||
|
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
|
||||||
|
sourceIndices.push_back(off + idx);
|
||||||
|
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
|
||||||
|
resultValues[i] = sourceValues[srcLinear];
|
||||||
|
}
|
||||||
|
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Direct copy from a global — just reuse its dense attribute
|
||||||
|
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
|
||||||
|
if (resultTensorType != denseAttr->getType())
|
||||||
|
return failure();
|
||||||
|
foldedAttr = *denseAttr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the alloc's remaining users are supported ops.
|
||||||
|
bool allLiveUsersAreCores = true;
|
||||||
|
for (Operation* user : allocOp->getUsers()) {
|
||||||
|
if (user == copyOp)
|
||||||
|
continue;
|
||||||
|
if (isa<memref::DeallocOp>(user))
|
||||||
|
continue;
|
||||||
|
if (isa<pim::PimCoreOp>(user))
|
||||||
|
continue;
|
||||||
|
if (isa<memref::SubViewOp>(user)) {
|
||||||
|
allLiveUsersAreCores = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
|
||||||
|
if (allLiveUsersAreCores)
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(allocOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
|
||||||
|
if (allLiveUsersAreCores)
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
|
||||||
|
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
|
||||||
|
rewriter.eraseOp(copyOp);
|
||||||
|
if (allocOp.use_empty())
|
||||||
|
rewriter.eraseOp(allocOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
|
||||||
|
// Only handle subviews whose users are all pim.core ops.
|
||||||
|
if (subviewOp.use_empty())
|
||||||
|
return failure();
|
||||||
|
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Source must resolve to a constant get_global.
|
||||||
|
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
|
||||||
|
if (!moduleOp)
|
||||||
|
return failure();
|
||||||
|
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
|
||||||
|
if (failed(denseAttr))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Static subview info.
|
||||||
|
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
|
||||||
|
if (failed(subviewInfo))
|
||||||
|
return failure();
|
||||||
|
if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Build the contiguous result type.
|
||||||
|
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
|
||||||
|
auto resultMemRefType = MemRefType::get(
|
||||||
|
SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
|
||||||
|
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
|
||||||
|
const int64_t numResultElements = resultTensorType.getNumElements();
|
||||||
|
|
||||||
|
// Extract the sub-tensor.
|
||||||
|
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> resultValues(numResultElements);
|
||||||
|
for (int64_t i = 0; i < numResultElements; ++i) {
|
||||||
|
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
|
||||||
|
SmallVector<int64_t> sourceIndices;
|
||||||
|
sourceIndices.reserve(resultIndices.size());
|
||||||
|
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
|
||||||
|
sourceIndices.push_back(off + idx);
|
||||||
|
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
|
||||||
|
}
|
||||||
|
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
|
||||||
|
|
||||||
|
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
|
||||||
|
markWeightAlways(newGlobal);
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(subviewOp);
|
||||||
|
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
|
||||||
|
markWeightAlways(newGetGlobal);
|
||||||
|
|
||||||
|
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
|
||||||
|
|
||||||
@@ -591,7 +845,13 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
|
|||||||
for (RegisteredOperationName op : context->getRegisteredOperations())
|
for (RegisteredOperationName op : context->getRegisteredOperations())
|
||||||
op.getCanonicalizationPatterns(owningPatterns, context);
|
op.getCanonicalizationPatterns(owningPatterns, context);
|
||||||
owningPatterns
|
owningPatterns
|
||||||
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
|
.add<FoldConstantTransposePattern,
|
||||||
|
FoldConstantAllocPattern,
|
||||||
|
FoldConstantCoreMapPattern,
|
||||||
|
RewriteCoreSubviewCopyPattern,
|
||||||
|
RewriteHostSubviewLoadPattern,
|
||||||
|
FoldConstantMemCpPattern,
|
||||||
|
FoldConstantCoreSubviewPattern>(
|
||||||
context);
|
context);
|
||||||
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
47
validation/operations/README.md
Normal file
47
validation/operations/README.md
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Validation Operations
|
||||||
|
|
||||||
|
ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.
|
||||||
|
|
||||||
|
Generated tests can be regenerated with:
|
||||||
|
```
|
||||||
|
python3 validation/operations/gen_tests.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Conv
|
||||||
|
|
||||||
|
| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|
||||||
|
|------|-----------|-------|--------|--------|--------|---------|------|-------|
|
||||||
|
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
|
||||||
|
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
|
||||||
|
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
|
||||||
|
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
|
||||||
|
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
|
||||||
|
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
|
||||||
|
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
|
||||||
|
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
|
||||||
|
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
|
||||||
|
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
|
||||||
|
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
||||||
|
|
||||||
|
## Gemm
|
||||||
|
|
||||||
|
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||||
|
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
|
||||||
|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
|
||||||
|
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||||
|
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
||||||
|
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
||||||
|
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
|
||||||
|
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
|
||||||
|
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
|
||||||
|
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
|
||||||
|
|
||||||
|
## Gemv
|
||||||
|
|
||||||
|
| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|
||||||
|
|------|-----------|-------|------------|--------|------|-------|
|
||||||
|
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
|
||||||
|
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
|
||||||
|
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
|
||||||
|
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
|
||||||
|
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |
|
||||||
BIN
validation/operations/conv/batch_2/conv_batch_2.onnx
Normal file
BIN
validation/operations/conv/batch_2/conv_batch_2.onnx
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
validation/operations/conv/kernel_3x3/conv_kernel_3x3.onnx
Normal file
BIN
validation/operations/conv/kernel_3x3/conv_kernel_3x3.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/large_spatial/conv_large_spatial.onnx
Normal file
BIN
validation/operations/conv/large_spatial/conv_large_spatial.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/multi_channel/conv_multi_channel.onnx
Normal file
BIN
validation/operations/conv/multi_channel/conv_multi_channel.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/pointwise_1x1/conv_1x1.onnx
Normal file
BIN
validation/operations/conv/pointwise_1x1/conv_1x1.onnx
Normal file
Binary file not shown.
Binary file not shown.
BIN
validation/operations/conv/stride_2/conv_stride_2.onnx
Normal file
BIN
validation/operations/conv/stride_2/conv_stride_2.onnx
Normal file
Binary file not shown.
BIN
validation/operations/conv/with_bias_3x3/conv_with_bias_3x3.onnx
Normal file
BIN
validation/operations/conv/with_bias_3x3/conv_with_bias_3x3.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemm/alpha_beta/gemm_alpha_beta.onnx
Normal file
BIN
validation/operations/gemm/alpha_beta/gemm_alpha_beta.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemm/large/gemm_large.onnx
Normal file
BIN
validation/operations/gemm/large/gemm_large.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemm/non_square/gemm_non_square.onnx
Normal file
BIN
validation/operations/gemm/non_square/gemm_non_square.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemm/small/gemm_small.onnx
Normal file
BIN
validation/operations/gemm/small/gemm_small.onnx
Normal file
Binary file not shown.
BIN
validation/operations/gemm/transB/gemm_transB.onnx
Normal file
BIN
validation/operations/gemm/transB/gemm_transB.onnx
Normal file
Binary file not shown.
Binary file not shown.
BIN
validation/operations/gemm/with_bias/gemm_with_bias.onnx
Normal file
BIN
validation/operations/gemm/with_bias/gemm_with_bias.onnx
Normal file
Binary file not shown.
276
validation/operations/gen_tests.py
Normal file
276
validation/operations/gen_tests.py
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Generate ONNX test models for validating GEMM and Conv implementations."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import onnx
|
||||||
|
from onnx import helper, TensorProto, numpy_helper
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
OPERATIONS_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model, directory, filename):
|
||||||
|
"""Save an ONNX model, creating the directory if needed."""
|
||||||
|
d = OPERATIONS_DIR / directory
|
||||||
|
d.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = d / filename
|
||||||
|
onnx.checker.check_model(model)
|
||||||
|
onnx.save(model, str(path))
|
||||||
|
print(f" {path.relative_to(OPERATIONS_DIR)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GEMM tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def gemm_non_square():
|
||||||
|
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||||
|
B, K, N = 4, 128, 64
|
||||||
|
W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/non_square", "gemm_non_square.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_with_bias():
|
||||||
|
"""GEMM with bias: Y = A @ W + C."""
|
||||||
|
B, K, N = 4, 128, 128
|
||||||
|
rng = np.random.default_rng(43)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_transB():
|
||||||
|
"""GEMM with transB=1: Y = A @ W^T."""
|
||||||
|
B, K, N = 4, 128, 64
|
||||||
|
rng = np.random.default_rng(44)
|
||||||
|
# W stored as [N, K], transposed during computation
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
|
||||||
|
graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/transB", "gemm_transB.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_alpha_beta():
|
||||||
|
"""GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
|
||||||
|
B, K, N = 4, 64, 64
|
||||||
|
rng = np.random.default_rng(45)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
|
||||||
|
graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_small():
|
||||||
|
"""Small GEMM: [2, 8] @ [8, 4]."""
|
||||||
|
B, K, N = 2, 8, 4
|
||||||
|
rng = np.random.default_rng(46)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/small", "gemm_small.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_large():
|
||||||
|
"""Larger GEMM: [8, 256] @ [256, 128]."""
|
||||||
|
B, K, N = 8, 256, 128
|
||||||
|
rng = np.random.default_rng(47)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/large", "gemm_large.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_transB_with_bias():
|
||||||
|
"""GEMM with transB and bias: Y = A @ W^T + C."""
|
||||||
|
B, K, N = 4, 128, 64
|
||||||
|
rng = np.random.default_rng(48)
|
||||||
|
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
|
||||||
|
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||||
|
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
|
||||||
|
graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Conv tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def conv_3x3_kernel():
|
||||||
|
"""Conv with 3x3 kernel, no padding."""
|
||||||
|
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||||
|
W = numpy_helper.from_array(
|
||||||
|
np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||||
|
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_stride2():
|
||||||
|
"""Conv with 3x3 kernel and stride 2."""
|
||||||
|
# Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
|
||||||
|
W = numpy_helper.from_array(
|
||||||
|
np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
|
||||||
|
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "conv/stride_2", "conv_stride_2.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_multi_channel():
|
||||||
|
"""Conv with multiple input and output channels."""
|
||||||
|
# Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
|
||||||
|
W = numpy_helper.from_array(
|
||||||
|
np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||||
|
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_1x1():
    """1x1 pointwise convolution (channel mixing)."""
    # Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
    rng = np.random.default_rng(53)
    weights = numpy_helper.from_array(
        rng.uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
    conv_node = helper.make_node(
        "Conv", ["X", "W"], ["Y"],
        kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph(
        [conv_node], "conv_1x1", [x_info], [y_info], initializer=[weights])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_same_padding_3x3():
    """Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
    # Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
    rng = np.random.default_rng(54)
    weights = numpy_helper.from_array(
        rng.uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
    conv_node = helper.make_node(
        "Conv", ["X", "W"], ["Y"],
        kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
    graph = helper.make_graph(
        [conv_node], "conv_same_3x3", [x_info], [y_info], initializer=[weights])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_explicit_padding():
    """Conv 3x3 with explicit symmetric padding (1 px on every edge).

    Note: the pads below are symmetric ([top, left, bottom, right] all 1);
    the previous docstring incorrectly described them as asymmetric.
    """
    # Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
    # Fixed seed keeps the generated model (and validation reference) reproducible.
    W = numpy_helper.from_array(
        np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
    graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_with_bias_3x3():
    """Conv 3x3 with bias."""
    # Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
    rng = np.random.default_rng(56)
    weights = numpy_helper.from_array(
        rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
    bias = numpy_helper.from_array(
        rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
    conv_node = helper.make_node(
        "Conv", ["X", "W", "B"], ["Y"],
        kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph(
        [conv_node], "conv_with_bias_3x3", [x_info], [y_info],
        initializer=[weights, bias])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_batch_2():
    """Batched conv (batch=2) with SAME_UPPER padding and bias."""
    # Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
    rng = np.random.default_rng(57)
    weights = numpy_helper.from_array(
        rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
    bias = numpy_helper.from_array(
        rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
    conv_node = helper.make_node(
        "Conv", ["X", "W", "B"], ["Y"],
        kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
    graph = helper.make_graph(
        [conv_node], "conv_batch_2", [x_info], [y_info],
        initializer=[weights, bias])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/batch_2", "conv_batch_2.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_large_spatial():
    """Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
    # Input: [1, 1, 8, 8], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 6, 6]
    rng = np.random.default_rng(58)
    weights = numpy_helper.from_array(
        rng.uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
    conv_node = helper.make_node(
        "Conv", ["X", "W"], ["Y"],
        kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph(
        [conv_node], "conv_large_spatial", [x_info], [y_info],
        initializer=[weights])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Generators run in a fixed order; each writes one .onnx model to disk.
    print("Generating GEMM tests:")
    for gemm_case in (
            gemm_non_square,
            gemm_with_bias,
            gemm_transB,
            gemm_alpha_beta,
            gemm_small,
            gemm_large,
            gemm_transB_with_bias):
        gemm_case()

    print("\nGenerating Conv tests:")
    for conv_case in (
            conv_3x3_kernel,
            conv_stride2,
            conv_multi_channel,
            conv_1x1,
            conv_same_padding_3x3,
            conv_explicit_padding,
            conv_with_bias_3x3,
            conv_batch_2,
            conv_large_spatial):
        conv_case()

    print("\nDone.")
|
||||||
Reference in New Issue
Block a user