extend operation support for conv and gemm
add more tests in validation
@@ -5,9 +5,11 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
  Math/Gemm.cpp
  Math/Conv.cpp
  Math/MatMul.cpp
  NN/Pooling.cpp
  NN/ReduceMean.cpp
  Tensor/ONNXConcatToTensorConcat.cpp
  Tensor/ONNXReshapeToTensorReshape.cpp
  Tensor/RemoveUnusedHelperOps.cpp
  Utils/SpatialReducer.cpp
  Utils/WeightSubdivider.cpp

@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
  });
  Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));

  // Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
  // Pass bias through directly; Gemm handles rank-1 C canonicalization.
  bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
  Value gemmC;
  if (hasB) {
    auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
    gemmC = tensor::ExpandShapeOp::create(rewriter,
        loc,
        biasType,
        b,
        SmallVector<ReassociationIndices> {
            {0, 1}
        });
  }
  if (hasB)
    gemmC = b;
  else
    gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

@@ -23,6 +23,38 @@ namespace {

constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";

static FailureOr<Value> materializeScaledConstantTensor(Value value,
    float factor,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
  if (!denseAttr)
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  APFloat scale(factor);
  bool hadFailure = false;
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
      hadFailure = true;
    scaledValues.push_back(std::move(scaledValue));
  }
  if (hadFailure)
    return failure();

  auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
  return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}

struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
  if (numOutRows <= 1)
    return failure();

  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
  }
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
        aSlice,
        b,
        cSlice,
        gemmOp.getAlphaAttr(),
        gemmOp.getBetaAttr(),
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getF32FloatAttr(1.0f),
        gemmOp.getTransAAttr(),
        gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
  }

@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
    auto bShape = bType.getShape();
    auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }

  if (alpha != 1.0f) {
    auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
    auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
    auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
    a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
    auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
    if (failed(scaledB))
      return failure();
    b = *scaledB;
    bType = cast<RankedTensorType>(b.getType());
    alpha = 1.0f;
  }
  if (hasC && beta != 1.0f) {
    auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
    auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
    auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
    c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    beta = 1.0f;
  }

  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());

@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
    auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
    auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
    auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
    if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
        || !outType.hasStaticShape())
      return failure();
    if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
      return failure();

    const int64_t batch = rhsType.getDimSize(0);
    const int64_t k = rhsType.getDimSize(1);
    const int64_t n = rhsType.getDimSize(2);
    const int64_t m = lhsType.getDimSize(0);
    if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
        || outType.getDimSize(2) != n)
      return failure();

    Location loc = matmulOp.getLoc();
    auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
    auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
    auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
    auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
    auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
    auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());

    Value lhsTransposed =
        ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
    Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

    SmallVector<Value> gemmRows;
    gemmRows.reserve(batch * n);
    for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
      for (int64_t colIdx = 0; colIdx < n; colIdx++) {
        SmallVector<OpFoldResult> offsets = {
            rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
        SmallVector<OpFoldResult> sizes = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
        SmallVector<OpFoldResult> strides = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        Value rhsSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
        Value rhsRow = tensor::CollapseShapeOp::create(
            rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});

        auto gemmOp = ONNXGemmOp::create(rewriter,
            loc,
            gemmRowType,
            rhsRow,
            lhsTransposed,
            none,
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getF32FloatAttr(1.0f),
            rewriter.getBoolAttr(false),
            rewriter.getBoolAttr(false));
        gemmRows.push_back(gemmOp.getY());
      }
    }

    auto concatComputeOp =
        spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);

    auto* concatBlock = new Block();
    for (Value gemmRow : gemmRows)
      concatBlock->addArgument(gemmRow.getType(), loc);
    concatComputeOp.getBody().push_back(concatBlock);
    rewriter.setInsertionPointToStart(concatBlock);

    auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
    spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());

    rewriter.setInsertionPointAfter(concatComputeOp);
    Value gemmOut = concatComputeOp.getResult(0);
    Value gemmExpanded = tensor::ExpandShapeOp::create(
        rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
    Value result = ONNXTransposeOp::create(
        rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));

    rewriter.replaceOp(matmulOp, result);
    return success();
  }
};

} // namespace

void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<MatMulRank3ToGemm>(ctx);
}

} // namespace onnx_mlir

@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<

// ONNXMatMulOp to ONNXGemmOp patterns

def IsRank2Result: Constraint<
  CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
  "Result is rank 2">;

def matMulAddToGemmPattern : Pat<
  (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
  (ONNXGemmOp $A, $B, $C,
@@ -22,7 +26,8 @@ def matMulAddToGemmPattern : Pat<
    /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
    /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
    /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
  )
  ),
  [(IsRank2Result $matmulres)]
>;

def matMulToGemmPattern : Pat<
@@ -34,7 +39,8 @@ def matMulToGemmPattern : Pat<
    /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
    /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
    /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
  )
  ),
  [(IsRank2Result $matmulres)]
>;

// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern

@@ -1,6 +1,5 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

@@ -56,6 +55,7 @@ void ONNXToSpatialPass::runOnOperation() {
  mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
  mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
  mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
  populateMatMulRewritePatterns(mergeActivationPatterns, ctx);

  if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
    llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
@@ -74,7 +74,9 @@ void ONNXToSpatialPass::runOnOperation() {

  ConversionTarget target(*ctx);
  target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
  target.addIllegalOp<ONNXMatMulOp>();
  target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
    return cast<ShapedType>(op.getY().getType()).getRank() != 2;
  });
  target.addIllegalOp<ONNXGemmOp>();
  target.addIllegalOp<ONNXConvOp>();
  target.addIllegalOp<ONNXLRNOp>();
@@ -83,6 +85,7 @@ void ONNXToSpatialPass::runOnOperation() {
  target.addIllegalOp<ONNXConcatOp>();
  target.addIllegalOp<ONNXSoftmaxOp>();
  target.addIllegalOp<ONNXReduceMeanV13Op>();
  target.addIllegalOp<ONNXReshapeOp>();

  RewritePatternSet patterns(ctx);
  patterns.add<removeLRNPattern>(ctx);
@@ -90,6 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
  populateConvOpPatterns(patterns, ctx);
  populatePoolingTilingPattern(patterns, ctx);
  populateOnnxGemmOpPatterns(patterns, ctx);
  populateReshapeConversionPattern(patterns, ctx);

  populateONNXConcatToTensorConcatPattern(patterns, ctx);
  populateReduceMeanConversionPattern(patterns, ctx);

@@ -7,12 +7,16 @@ namespace onnx_mlir {

void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

@@ -0,0 +1,121 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
  return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}

static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
    ArrayRef<int64_t> resultShape,
    SmallVector<ReassociationIndices>& reassociation) {
  reassociation.clear();

  size_t sourceIdx = 0;
  size_t resultIdx = 0;
  while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
    int64_t sourceProduct = sourceShape[sourceIdx];
    int64_t resultProduct = resultShape[resultIdx];

    ReassociationIndices group;
    group.push_back(sourceIdx);
    while (sourceProduct != resultProduct) {
      if (sourceProduct > resultProduct)
        return false;
      sourceIdx++;
      if (sourceIdx >= sourceShape.size())
        return false;
      group.push_back(sourceIdx);
      sourceProduct *= sourceShape[sourceIdx];
    }

    reassociation.push_back(group);
    sourceIdx++;
    resultIdx++;
  }

  return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}

static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
    ArrayRef<int64_t> resultShape,
    SmallVector<ReassociationIndices>& reassociation) {
  reassociation.clear();

  size_t sourceIdx = 0;
  size_t resultIdx = 0;
  while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
    int64_t sourceProduct = sourceShape[sourceIdx];
    int64_t resultProduct = resultShape[resultIdx];

    ReassociationIndices group;
    group.push_back(resultIdx);
    while (resultProduct != sourceProduct) {
      if (resultProduct > sourceProduct)
        return false;
      resultIdx++;
      if (resultIdx >= resultShape.size())
        return false;
      group.push_back(resultIdx);
      resultProduct *= resultShape[resultIdx];
    }

    reassociation.push_back(group);
    sourceIdx++;
    resultIdx++;
  }

  return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}

struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
      ONNXReshapeOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
    if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();
    if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
      return failure();

    if (sourceType == resultType) {
      rewriter.replaceOp(reshapeOp, adaptor.getData());
      return success();
    }

    SmallVector<ReassociationIndices> reassociation;
    if (sourceType.getRank() > resultType.getRank()
        && inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
      return success();
    }

    if (sourceType.getRank() < resultType.getRank()
        && inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
      rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
      return success();
    }

    return failure();
  }
};

} // namespace

void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<ONNXReshapeToTensorReshape>(ctx);
}

} // namespace onnx_mlir

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -79,8 +80,31 @@ private:
} // namespace

static bool isChannelUseChainOp(Operation* op) {
  return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
      op);
  return isa<tensor::ExtractSliceOp,
      tensor::CollapseShapeOp,
      tensor::ExpandShapeOp,
      tensor::CastOp,
      tosa::ReshapeOp,
      pim::PimTransposeOp>(op);
}

static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
  for (Value operand : op->getOperands()) {
    if (mapping.lookupOrNull(operand))
      continue;

    Operation* definingOp = operand.getDefiningOp();
    if (!definingOp)
      continue;

    if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
      continue;

    Operation* clonedOp = rewriter.clone(*definingOp, mapping);
    for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);
    rewriter.setInsertionPointAfter(clonedOp);
  }
}

static size_t countComputeLeafUsers(Value value) {
@@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
  OpOperand& resultUse = *resultUses.begin();
  Operation* resultUser = resultUse.getOwner();

  if (isChannelUseChainOp(resultUser)) {
    SmallVector<Operation*> returnChain;
    Value chainedValue = result;
    Operation* chainUser = resultUser;

    while (isChannelUseChainOp(chainUser)) {
      returnChain.push_back(chainUser);
      auto chainUses = chainUser->getResult(0).getUses();
      if (rangeLength(chainUses) != 1)
        break;
      chainedValue = chainUser->getResult(0);
      chainUser = chainUses.begin()->getOwner();
    }

    if (isa<func::ReturnOp>(chainUser)) {
      size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();

      rewriter.setInsertionPoint(yieldOp);
      IRMapping mapping;
      mapping.map(result, yieldValue);

      Value storedValue = yieldValue;
      for (Operation* op : returnChain) {
        cloneMappedHelperOperands(op, mapping, rewriter);
        Operation* clonedOp = rewriter.clone(*op, mapping);
        for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
          mapping.map(originalResult, newResult);
        storedValue = clonedOp->getResult(0);
        rewriter.setInsertionPointAfter(clonedOp);
        markOpToRemove(op);
      }

      auto storedType = cast<ShapedType>(storedValue.getType());
      size_t elementSize = storedType.getElementTypeBitWidth() / 8;

      Value outputTensor = outputTensors[resultIndexInReturn];
      if (auto storedOp = storedValue.getDefiningOp())
        rewriter.setInsertionPointAfter(storedOp);
      PimMemCopyDevToHostOp::create(rewriter,
          loc,
          outputTensor.getType(),
          outputTensor,
          storedValue,
          rewriter.getI32IntegerAttr(0),
          rewriter.getI32IntegerAttr(0),
          rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
      continue;
    }
  }

  if (isa<func::ReturnOp>(resultUser)) {
    size_t resultIndexInReturn = resultUse.getOperandNumber();
    size_t offset = 0;
@@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
  IRMapping mapping;
  mapping.map(channelSourceOp, receivedValue);
  for (Operation* op : llvm::reverse(clonedChain)) {
    cloneMappedHelperOperands(op, mapping, rewriter);
    Operation* clonedOp = rewriter.clone(*op, mapping);
    for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
      mapping.map(originalResult, newResult);

@@ -30,6 +30,24 @@ static Value stripMemRefCasts(Value value) {
  return value;
}

static Value stripMemRefViewOps(Value value) {
  while (true) {
    if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
      value = castOp.getSource();
      continue;
    }
    if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
      value = collapseOp.getSrc();
      continue;
    }
    if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
      value = expandOp.getSrc();
      continue;
    }
    return value;
  }
}

static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
    Location loc,
    MemRefType globalType,
@@ -204,6 +222,7 @@ struct StaticSubviewInfo {
};

static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
  value = stripMemRefViewOps(value);
  auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
  if (!subviewOp)
    return failure();
@@ -321,6 +340,77 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
  }
};

struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
    auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
    auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
    const bool splitSrc = succeeded(srcSubview)
        && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
    const bool splitDst = succeeded(dstSubview)
        && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
    if (!splitSrc && !splitDst)
      return failure();

    auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
    auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
    if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
      return failure();
    if (sourceType.getElementType() != dstType.getElementType())
      return failure();

    if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
      return failure();
    if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
      return failure();

    ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
    if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
      return failure();

    const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
    if (elementByteWidth <= 0)
      return failure();

    const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
    if (copyOp.getSize() != totalBytes)
      return failure();

    const int64_t sliceBytes = copyShape.back() * elementByteWidth;
    if (sliceBytes <= 0)
      return failure();

    SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
    auto outerStrides = computeRowMajorStrides(outerShape);
    const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);

    rewriter.setInsertionPoint(copyOp);
    for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
      SmallVector<int64_t> outerIndices =
          outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
      const int64_t srcByteOffset = copyOp.getHostSrcOffset()
          + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
                      : linearIndex * sliceBytes);
      const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
          + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
                      : linearIndex * sliceBytes);
      pim::PimMemCopyHostToDevOp::create(
          rewriter,
          copyOp.getLoc(),
          splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
          splitDst ? dstSubview->source : copyOp.getDeviceDst(),
          splitSrc ? srcSubview->source : copyOp.getHostSrc(),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
          rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
    }

    rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
    return success();
  }
};

static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
  auto allocType = dyn_cast<MemRefType>(allocOp.getType());
  if (!allocType || !allocType.hasStaticShape())
@@ -578,6 +668,170 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
  }
};

struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
    // Only match top-level memcp (not inside pim.core)
    if (copyOp->getParentOfType<pim::PimCoreOp>())
      return failure();

    // dst must be an alloc with static shape
    auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
    if (!allocOp)
      return failure();
    auto allocType = dyn_cast<MemRefType>(allocOp.getType());
    if (!allocType || !allocType.hasStaticShape())
      return failure();

    // The copy must cover the full destination (offsets both zero)
    if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
      return failure();

    // Resolve the source through an optional subview to a get_global
    auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
    Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());

    auto moduleOp = copyOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();

    auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
    if (failed(denseAttr))
      return failure();

    // Build the folded dense attribute
    DenseElementsAttr foldedAttr;
    if (succeeded(srcSubview)) {
      // Extract the sub-tensor from the source constant
      auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
      if (!sourceType || !sourceType.hasStaticShape())
        return failure();
      if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
        return failure();

      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
      const int64_t numResultElements = resultTensorType.getNumElements();
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
      SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
      SmallVector<Attribute> resultValues(numResultElements);

      for (int64_t i = 0; i < numResultElements; ++i) {
        auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
        SmallVector<int64_t> sourceIndices;
        sourceIndices.reserve(resultIndices.size());
        for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
          sourceIndices.push_back(off + idx);
        int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
        resultValues[i] = sourceValues[srcLinear];
      }
      foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
    }
    else {
      // Direct copy from a global — just reuse its dense attribute
      auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
      if (resultTensorType != denseAttr->getType())
        return failure();
      foldedAttr = *denseAttr;
    }

    // Verify that the alloc's remaining users are supported ops.
    bool allLiveUsersAreCores = true;
    for (Operation* user : allocOp->getUsers()) {
      if (user == copyOp)
        continue;
      if (isa<memref::DeallocOp>(user))
        continue;
      if (isa<pim::PimCoreOp>(user))
        continue;
      if (isa<memref::SubViewOp>(user)) {
        allLiveUsersAreCores = false;
        continue;
      }
      return failure();
    }

    auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
    if (allLiveUsersAreCores)
      markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(allocOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
    if (allLiveUsersAreCores)
      markWeightAlways(newGetGlobal);

    rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
    rewriter.eraseOp(copyOp);
    if (allocOp.use_empty())
      rewriter.eraseOp(allocOp);
    return success();
  }
};

struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
    // Only handle subviews whose users are all pim.core ops.
    if (subviewOp.use_empty())
      return failure();
    if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
      return failure();

    // Source must resolve to a constant get_global.
    auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
    if (!moduleOp)
      return failure();
    auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
    if (failed(denseAttr))
      return failure();

    // Static subview info.
    auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
    if (failed(subviewInfo))
      return failure();
    if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
      return failure();

    auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return failure();

    // Build the contiguous result type.
    auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
    auto resultMemRefType = MemRefType::get(
        SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
    auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
    const int64_t numResultElements = resultTensorType.getNumElements();

    // Extract the sub-tensor.
    auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
    auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
    SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
    SmallVector<Attribute> resultValues(numResultElements);
    for (int64_t i = 0; i < numResultElements; ++i) {
      auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
      SmallVector<int64_t> sourceIndices;
      sourceIndices.reserve(resultIndices.size());
      for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
        sourceIndices.push_back(off + idx);
      resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
    }
    auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);

    auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
    markWeightAlways(newGlobal);

    rewriter.setInsertionPoint(subviewOp);
    auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
    markWeightAlways(newGetGlobal);

    rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
    return success();
  }
};

struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

@@ -591,7 +845,13 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
    for (RegisteredOperationName op : context->getRegisteredOperations())
      op.getCanonicalizationPatterns(owningPatterns, context);
    owningPatterns
        .add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
        .add<FoldConstantTransposePattern,
            FoldConstantAllocPattern,
            FoldConstantCoreMapPattern,
            RewriteCoreSubviewCopyPattern,
            RewriteHostSubviewLoadPattern,
            FoldConstantMemCpPattern,
            FoldConstantCoreSubviewPattern>(
            context);
    patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
    return success();

@@ -0,0 +1,47 @@
# Validation Operations

ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.

Generated tests can be regenerated with:
```
python3 validation/operations/gen_tests.py
```
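
The exact interface of `validate.py` is not documented here; as a rough illustration of what the generated models exercise, a model can be compared against a NumPy reference with onnxruntime (assumed to be installed), e.g. for the `gemm/non_square` shapes listed below:

```
# Illustrative sketch only; validate.py's real flow may differ.
import numpy as np
import onnx
import onnxruntime as ort
from onnx import numpy_helper

model_path = "gemm/non_square/gemm_non_square.onnx"
weights = {init.name: numpy_helper.to_array(init)
           for init in onnx.load(model_path).graph.initializer}

a = np.random.default_rng(0).uniform(-1, 1, (4, 128)).astype(np.float32)
expected = a @ weights["W"]  # default Gemm: alpha=1, beta=1, no bias

(actual,) = ort.InferenceSession(model_path).run(None, {"A": a})
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-5)
```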

## Conv

| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|------|-----------|-------|--------|--------|--------|---------|------|-------|
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
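
For the rows with explicit or no padding, the output spatial size follows the standard convolution formula `out = (in + pad_total - kernel) // stride + 1`; a quick sanity check of a few table entries (illustrative helper, not part of the repository):

```
def conv_out_dim(in_dim, kernel, stride=1, pad_total=0):
    """Spatial output size of a standard (non-dilated) convolution."""
    return (in_dim + pad_total - kernel) // stride + 1

assert conv_out_dim(5, 3) == 3               # conv/kernel_3x3
assert conv_out_dim(6, 3, stride=2) == 2     # conv/stride_2
assert conv_out_dim(4, 3, pad_total=2) == 4  # conv/explicit_padding, pads=[1,1,1,1]
```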

## Gemm

| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
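
These cases follow the standard ONNX Gemm semantics `Y = alpha * A' * B' + beta * C`, where `A'`/`B'` are transposed when transA/transB is set; for example, the `alpha_beta` and `transB` rows correspond to (NumPy sketch using the shapes from the table):

```
import numpy as np

rng = np.random.default_rng(0)
A, W, C = rng.normal(size=(4, 64)), rng.normal(size=(64, 64)), rng.normal(size=64)
Y_alpha_beta = 0.5 * (A @ W) + 0.25 * C  # gemm/alpha_beta: alpha=0.5, beta=0.25, bias [64]

A2, W2 = rng.normal(size=(4, 128)), rng.normal(size=(64, 128))
Y_transB = A2 @ W2.T                     # gemm/transB: W stored as [N, K] = [64, 128]
assert Y_alpha_beta.shape == (4, 64) and Y_transB.shape == (4, 64)
```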

## Gemv

| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|------|-----------|-------|------------|--------|------|-------|
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM and Conv implementations."""

import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
from pathlib import Path

OPERATIONS_DIR = Path(__file__).parent


def save_model(model, directory, filename):
    """Save an ONNX model, creating the directory if needed."""
    d = OPERATIONS_DIR / directory
    d.mkdir(parents=True, exist_ok=True)
    path = d / filename
    onnx.checker.check_model(model)
    onnx.save(model, str(path))
    print(f" {path.relative_to(OPERATIONS_DIR)}")


# ---------------------------------------------------------------------------
# GEMM tests
# ---------------------------------------------------------------------------

def gemm_non_square():
    """GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
    B, K, N = 4, 128, 64
    W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W"], ["Y"])
    graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/non_square", "gemm_non_square.onnx")


def gemm_with_bias():
    """GEMM with bias: Y = A @ W + C."""
    B, K, N = 4, 128, 128
    rng = np.random.default_rng(43)
    W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
    C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
    graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")


def gemm_transB():
    """GEMM with transB=1: Y = A @ W^T."""
    B, K, N = 4, 128, 64
    rng = np.random.default_rng(44)
    # W stored as [N, K], transposed during computation
    W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
    graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/transB", "gemm_transB.onnx")


def gemm_alpha_beta():
    """GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
    B, K, N = 4, 64, 64
    rng = np.random.default_rng(45)
    W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
    C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
    graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")


def gemm_small():
    """Small GEMM: [2, 8] @ [8, 4]."""
    B, K, N = 2, 8, 4
    rng = np.random.default_rng(46)
    W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W"], ["Y"])
    graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/small", "gemm_small.onnx")


def gemm_large():
    """Larger GEMM: [8, 256] @ [256, 128]."""
    B, K, N = 8, 256, 128
    rng = np.random.default_rng(47)
    W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W"], ["Y"])
    graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/large", "gemm_large.onnx")


def gemm_transB_with_bias():
    """GEMM with transB and bias: Y = A @ W^T + C."""
    B, K, N = 4, 128, 64
    rng = np.random.default_rng(48)
    W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
    C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
    A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
    node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
    graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")


# ---------------------------------------------------------------------------
# Conv tests
# ---------------------------------------------------------------------------

def conv_3x3_kernel():
    """Conv with 3x3 kernel, no padding."""
    # Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
    W = numpy_helper.from_array(
        np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")


def conv_stride2():
    """Conv with 3x3 kernel and stride 2."""
    # Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
    W = numpy_helper.from_array(
        np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/stride_2", "conv_stride_2.onnx")


def conv_multi_channel():
    """Conv with multiple input and output channels."""
    # Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
    W = numpy_helper.from_array(
        np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")


def conv_1x1():
    """1x1 pointwise convolution (channel mixing)."""
    # Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
    W = numpy_helper.from_array(
        np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")


def conv_same_padding_3x3():
    """Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
    # Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
    W = numpy_helper.from_array(
        np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
    graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")


def conv_explicit_padding():
    """Conv 3x3 with explicit symmetric padding."""
    # Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
    W = numpy_helper.from_array(
        np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
    graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")


def conv_with_bias_3x3():
    """Conv 3x3 with bias."""
    # Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
    rng = np.random.default_rng(56)
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
    W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
    B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
    node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")


def conv_batch_2():
    """Batched conv (batch=2) with SAME_UPPER padding and bias."""
    # Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
    rng = np.random.default_rng(57)
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
    W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
    B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
    node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
                            kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
    graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/batch_2", "conv_batch_2.onnx")


def conv_large_spatial():
    """Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
    W = numpy_helper.from_array(
        np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
    node = helper.make_node("Conv", ["X", "W"], ["Y"],
                            kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print("Generating GEMM tests:")
    gemm_non_square()
    gemm_with_bias()
    gemm_transB()
    gemm_alpha_beta()
    gemm_small()
    gemm_large()
    gemm_transB_with_bias()

    print("\nGenerating Conv tests:")
    conv_3x3_kernel()
    conv_stride2()
    conv_multi_channel()
    conv_1x1()
    conv_same_padding_3x3()
    conv_explicit_padding()
    conv_with_bias_3x3()
    conv_batch_2()
    conv_large_spatial()

    print("\nDone.")