extend operation support for conv and gemm

add more tests in validation
This commit is contained in:
NiccoloN
2026-03-23 14:46:08 +01:00
parent 2676f2c7ef
commit 670d6ce94f
29 changed files with 982 additions and 29 deletions

View File

@@ -5,9 +5,11 @@ add_public_tablegen_target(ONNXToSpatialIncGen)
add_onnx_mlir_library(OMONNXToSpatial
Math/Gemm.cpp
Math/Conv.cpp
Math/MatMul.cpp
NN/Pooling.cpp
NN/ReduceMean.cpp
Tensor/ONNXConcatToTensorConcat.cpp
Tensor/ONNXReshapeToTensorReshape.cpp
Tensor/RemoveUnusedHelperOps.cpp
Utils/SpatialReducer.cpp
Utils/WeightSubdivider.cpp

View File

@@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB) {
auto biasType = RankedTensorType::get({1, numChannelsOut}, cast<RankedTensorType>(b.getType()).getElementType());
gemmC = tensor::ExpandShapeOp::create(rewriter,
loc,
biasType,
b,
SmallVector<ReassociationIndices> {
{0, 1}
});
}
if (hasB)
gemmC = b;
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());

View File

@@ -23,6 +23,38 @@ namespace {
constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor";
static FailureOr<Value> materializeScaledConstantTensor(Value value,
float factor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (factor == 1.0f)
return value;
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp)
return failure();
auto denseAttr = dyn_cast<DenseFPElementsAttr>(constantOp.getValue());
if (!denseAttr)
return failure();
SmallVector<APFloat> scaledValues;
scaledValues.reserve(denseAttr.getNumElements());
APFloat scale(factor);
bool hadFailure = false;
for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
APFloat scaledValue(originalValue);
if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven))
hadFailure = true;
scaledValues.push_back(std::move(scaledValue));
}
if (hadFailure)
return failure();
auto scaledAttr = DenseFPElementsAttr::get(cast<RankedTensorType>(denseAttr.getType()), scaledValues);
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
@@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
if (numOutRows <= 1)
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
RankedTensorType cType = nullptr;
bool cHasNumOutRows = false;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
@@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
aSlice,
b,
cSlice,
gemmOp.getAlphaAttr(),
gemmOp.getBetaAttr(),
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
gemmOp.getTransAAttr(),
gemmOp.getTransBAttr());
gemvOps.push_back(gemvOp.getY());
@@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
if (hasC) {
cType = cast<RankedTensorType>(c.getType());
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector<ReassociationIndices>{{0, 1}});
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
}
@@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
bType = cast<RankedTensorType>(b.getType());
}
if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
alpha = 1.0f;
}
if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
if (failed(scaledC))
return failure();
c = *scaledC;
cType = cast<RankedTensorType>(c.getType());
beta = 1.0f;
}
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());

View File

@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
return failure();
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
return failure();
Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(
rewriter, loc, rhsRowType, rhsSlice, SmallVector<ReassociationIndices>{{0}, {1, 2}});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
}
auto concatComputeOp =
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector<Value>(), gemmRows);
auto* concatBlock = new Block();
for (Value gemmRow : gemmRows)
concatBlock->addArgument(gemmRow.getType(), loc);
concatComputeOp.getBody().push_back(concatBlock);
rewriter.setInsertionPointToStart(concatBlock);
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments());
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.setInsertionPointAfter(concatComputeOp);
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(
rewriter, loc, gemmExpandedType, gemmOut, SmallVector<ReassociationIndices>{{0, 1}, {2}});
Value result = ONNXTransposeOp::create(
rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
rewriter.replaceOp(matmulOp, result);
return success();
}
};
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx);
}
} // namespace onnx_mlir

View File

@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
// ONNXMatMulOp to ONNXGemmOp patterns
def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C,
@@ -22,7 +26,8 @@ def matMulAddToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
def matMulToGemmPattern : Pat<
@@ -34,7 +39,8 @@ def matMulToGemmPattern : Pat<
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
)
),
[(IsRank2Result $matmulres)]
>;
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern

View File

@@ -1,6 +1,5 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -56,6 +55,7 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
@@ -74,7 +74,9 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, ONNXDialect, tensor::TensorDialect, arith::ArithDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addDynamicallyLegalOp<ONNXMatMulOp>([](ONNXMatMulOp op) {
return cast<ShapedType>(op.getY().getType()).getRank() != 2;
});
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
@@ -83,6 +85,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXReshapeOp>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
@@ -90,6 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
populateConvOpPatterns(patterns, ctx);
populatePoolingTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateReduceMeanConversionPattern(patterns, ctx);

View File

@@ -7,12 +7,16 @@ namespace onnx_mlir {
void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);

View File

@@ -0,0 +1,121 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(sourceIdx);
while (sourceProduct != resultProduct) {
if (sourceProduct > resultProduct)
return false;
sourceIdx++;
if (sourceIdx >= sourceShape.size())
return false;
group.push_back(sourceIdx);
sourceProduct *= sourceShape[sourceIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
reassociation.clear();
size_t sourceIdx = 0;
size_t resultIdx = 0;
while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) {
int64_t sourceProduct = sourceShape[sourceIdx];
int64_t resultProduct = resultShape[resultIdx];
ReassociationIndices group;
group.push_back(resultIdx);
while (resultProduct != sourceProduct) {
if (resultProduct > sourceProduct)
return false;
resultIdx++;
if (resultIdx >= resultShape.size())
return false;
group.push_back(resultIdx);
resultProduct *= resultShape[resultIdx];
}
reassociation.push_back(group);
sourceIdx++;
resultIdx++;
}
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
struct ONNXReshapeToTensorReshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp,
ONNXReshapeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto sourceType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
return failure();
if (sourceType == resultType) {
rewriter.replaceOp(reshapeOp, adaptor.getData());
return success();
}
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
return failure();
}
};
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXReshapeToTensorReshape>(ctx);
}
} // namespace onnx_mlir

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -79,8 +80,31 @@ private:
} // namespace
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::CollapseShapeOp, tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp>(
op);
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static size_t countComputeLeafUsers(Value value) {
@@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isChannelUseChainOp(resultUser)) {
SmallVector<Operation*> returnChain;
Value chainedValue = result;
Operation* chainUser = resultUser;
while (isChannelUseChainOp(chainUser)) {
returnChain.push_back(chainUser);
auto chainUses = chainUser->getResult(0).getUses();
if (rangeLength(chainUses) != 1)
break;
chainedValue = chainUser->getResult(0);
chainUser = chainUses.begin()->getOwner();
}
if (isa<func::ReturnOp>(chainUser)) {
size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber();
rewriter.setInsertionPoint(yieldOp);
IRMapping mapping;
mapping.map(result, yieldValue);
Value storedValue = yieldValue;
for (Operation* op : returnChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
storedValue = clonedOp->getResult(0);
rewriter.setInsertionPointAfter(clonedOp);
markOpToRemove(op);
}
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn];
if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
outputTensor,
storedValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize));
continue;
}
}
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t offset = 0;
@@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
IRMapping mapping;
mapping.map(channelSourceOp, receivedValue);
for (Operation* op : llvm::reverse(clonedChain)) {
cloneMappedHelperOperands(op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);

View File

@@ -30,6 +30,24 @@ static Value stripMemRefCasts(Value value) {
return value;
}
static Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
Location loc,
MemRefType globalType,
@@ -204,6 +222,7 @@ struct StaticSubviewInfo {
};
static FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
@@ -321,6 +340,77 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
}
};
struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHostToDevOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override {
auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc());
auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst());
const bool splitSrc = succeeded(srcSubview)
&& !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides);
const bool splitDst = succeeded(dstSubview)
&& !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides);
if (!splitSrc && !splitDst)
return failure();
auto sourceType = dyn_cast<MemRefType>(copyOp.getHostSrc().getType());
auto dstType = dyn_cast<MemRefType>(copyOp.getDeviceDst().getType());
if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
if (sourceType.getElementType() != dstType.getElementType())
return failure();
if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; }))
return failure();
ArrayRef<int64_t> copyShape = splitSrc ? ArrayRef<int64_t>(srcSubview->sizes) : ArrayRef<int64_t>(dstSubview->sizes);
if (splitSrc && splitDst && copyShape != ArrayRef<int64_t>(dstSubview->sizes))
return failure();
const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8;
if (elementByteWidth <= 0)
return failure();
const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth;
if (copyOp.getSize() != totalBytes)
return failure();
const int64_t sliceBytes = copyShape.back() * elementByteWidth;
if (sliceBytes <= 0)
return failure();
SmallVector<int64_t> outerShape(copyShape.begin(), copyShape.end() - 1);
auto outerStrides = computeRowMajorStrides(outerShape);
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
rewriter.setInsertionPoint(copyOp);
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
SmallVector<int64_t> outerIndices =
outerShape.empty() ? SmallVector<int64_t>{} : delinearizeIndex(linearIndex, outerShape, outerStrides);
const int64_t srcByteOffset = copyOp.getHostSrcOffset()
+ (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
const int64_t dstByteOffset = copyOp.getDeviceDstOffset()
+ (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth)
: linearIndex * sliceBytes);
pim::PimMemCopyHostToDevOp::create(
rewriter,
copyOp.getLoc(),
splitDst ? cast<MemRefType>(dstSubview->source.getType()) : dstType,
splitDst ? dstSubview->source : copyOp.getDeviceDst(),
splitSrc ? srcSubview->source : copyOp.getHostSrc(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(dstByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(srcByteOffset)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(sliceBytes)));
}
rewriter.replaceOp(copyOp, copyOp.getDeviceDst());
return success();
}
};
static FailureOr<DenseElementsAttr> foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) {
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
@@ -578,6 +668,170 @@ struct FoldConstantAllocPattern final : OpRewritePattern<memref::AllocOp> {
}
};
struct FoldConstantMemCpPattern final : OpRewritePattern<pim::PimMemCopyOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override {
// Only match top-level memcp (not inside pim.core)
if (copyOp->getParentOfType<pim::PimCoreOp>())
return failure();
// dst must be an alloc with static shape
auto allocOp = copyOp.getDst().getDefiningOp<memref::AllocOp>();
if (!allocOp)
return failure();
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!allocType || !allocType.hasStaticShape())
return failure();
// The copy must cover the full destination (offsets both zero)
if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0)
return failure();
// Resolve the source through an optional subview to a get_global
auto srcSubview = getStaticSubviewInfo(copyOp.getSrc());
Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc());
auto moduleOp = copyOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, globalSource);
if (failed(denseAttr))
return failure();
// Build the folded dense attribute
DenseElementsAttr foldedAttr;
if (succeeded(srcSubview)) {
// Extract the sub-tensor from the source constant
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; }))
return failure();
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
const int64_t numResultElements = resultTensorType.getNumElements();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices))
sourceIndices.push_back(off + idx);
int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides);
resultValues[i] = sourceValues[srcLinear];
}
foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
}
else {
// Direct copy from a global — just reuse its dense attribute
auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType());
if (resultTensorType != denseAttr->getType())
return failure();
foldedAttr = *denseAttr;
}
// Verify that the alloc's remaining users are supported ops.
bool allLiveUsersAreCores = true;
for (Operation* user : allocOp->getUsers()) {
if (user == copyOp)
continue;
if (isa<memref::DeallocOp>(user))
continue;
if (isa<pim::PimCoreOp>(user))
continue;
if (isa<memref::SubViewOp>(user)) {
allLiveUsersAreCores = false;
continue;
}
return failure();
}
auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp");
if (allLiveUsersAreCores)
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(allocOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName());
if (allLiveUsersAreCores)
markWeightAlways(newGetGlobal);
rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult());
rewriter.eraseOp(copyOp);
if (allocOp.use_empty())
rewriter.eraseOp(allocOp);
return success();
}
};
struct FoldConstantCoreSubviewPattern final : OpRewritePattern<memref::SubViewOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override {
// Only handle subviews whose users are all pim.core ops.
if (subviewOp.use_empty())
return failure();
if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa<pim::PimCoreOp>(user); }))
return failure();
// Source must resolve to a constant get_global.
auto moduleOp = subviewOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return failure();
auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource()));
if (failed(denseAttr))
return failure();
// Static subview info.
auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult());
if (failed(subviewInfo))
return failure();
if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; }))
return failure();
auto sourceType = dyn_cast<RankedTensorType>(denseAttr->getType());
if (!sourceType || !sourceType.hasStaticShape())
return failure();
// Build the contiguous result type.
auto elementType = cast<MemRefType>(subviewOp.getType()).getElementType();
auto resultMemRefType = MemRefType::get(
SmallVector<int64_t>(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType);
auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType);
const int64_t numResultElements = resultTensorType.getNumElements();
// Extract the sub-tensor.
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
auto resultStrides = computeRowMajorStrides(resultTensorType.getShape());
SmallVector<Attribute> sourceValues(denseAttr->getValues<Attribute>());
SmallVector<Attribute> resultValues(numResultElements);
for (int64_t i = 0; i < numResultElements; ++i) {
auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides);
SmallVector<int64_t> sourceIndices;
sourceIndices.reserve(resultIndices.size());
for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices))
sourceIndices.push_back(off + idx);
resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)];
}
auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues);
auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview");
markWeightAlways(newGlobal);
rewriter.setInsertionPoint(subviewOp);
auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName());
markWeightAlways(newGetGlobal);
rewriter.replaceOp(subviewOp, newGetGlobal.getResult());
return success();
}
};
struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)
@@ -591,7 +845,13 @@ struct PimConstantFoldingPass : PassWrapper<PimConstantFoldingPass, OperationPas
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
owningPatterns
.add<FoldConstantTransposePattern, FoldConstantAllocPattern, FoldConstantCoreMapPattern, RewriteCoreSubviewCopyPattern>(
.add<FoldConstantTransposePattern,
FoldConstantAllocPattern,
FoldConstantCoreMapPattern,
RewriteCoreSubviewCopyPattern,
RewriteHostSubviewLoadPattern,
FoldConstantMemCpPattern,
FoldConstantCoreSubviewPattern>(
context);
patterns = std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
return success();

View File

@@ -0,0 +1,47 @@
# Validation Operations
ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline.
Generated tests can be regenerated with:
```
python3 validation/operations/gen_tests.py
```
## Conv
| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes |
|------|-----------|-------|--------|--------|--------|---------|------|-------|
| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted |
| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias |
| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input |
| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel |
| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution |
| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels |
| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing |
| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved |
| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads |
| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias |
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
## Gemm
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|------|-----------|-----------|------------|--------|--------|-------|------|------|-------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias |
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
## Gemv
| Test | Directory | Input | W (weight) | Output | Bias | Notes |
|------|-----------|-------|------------|--------|------|-------|
| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul |
| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant |
| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape |
| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern |
| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast |

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""Generate ONNX test models for validating GEMM and Conv implementations."""
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
from pathlib import Path
OPERATIONS_DIR = Path(__file__).parent
def save_model(model, directory, filename):
"""Save an ONNX model, creating the directory if needed."""
d = OPERATIONS_DIR / directory
d.mkdir(parents=True, exist_ok=True)
path = d / filename
onnx.checker.check_model(model)
onnx.save(model, str(path))
print(f" {path.relative_to(OPERATIONS_DIR)}")
# ---------------------------------------------------------------------------
# GEMM tests
# ---------------------------------------------------------------------------
def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64
W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/non_square", "gemm_non_square.onnx")
def gemm_with_bias():
"""GEMM with bias: Y = A @ W + C."""
B, K, N = 4, 128, 128
rng = np.random.default_rng(43)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"])
graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/with_bias", "gemm_with_bias.onnx")
def gemm_transB():
"""GEMM with transB=1: Y = A @ W^T."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(44)
# W stored as [N, K], transposed during computation
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB", "gemm_transB.onnx")
def gemm_alpha_beta():
"""GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C."""
B, K, N = 4, 64, 64
rng = np.random.default_rng(45)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25)
graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx")
def gemm_small():
"""Small GEMM: [2, 8] @ [8, 4]."""
B, K, N = 2, 8, 4
rng = np.random.default_rng(46)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/small", "gemm_small.onnx")
def gemm_large():
"""Larger GEMM: [8, 256] @ [256, 128]."""
B, K, N = 8, 256, 128
rng = np.random.default_rng(47)
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/large", "gemm_large.onnx")
def gemm_transB_with_bias():
"""GEMM with transB and bias: Y = A @ W^T + C."""
B, K, N = 4, 128, 64
rng = np.random.default_rng(48)
W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W")
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1)
graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
# ---------------------------------------------------------------------------
# Conv tests
# ---------------------------------------------------------------------------
def conv_3x3_kernel():
"""Conv with 3x3 kernel, no padding."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx")
def conv_stride2():
"""Conv with 3x3 kernel and stride 2."""
# Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
W = numpy_helper.from_array(
np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/stride_2", "conv_stride_2.onnx")
def conv_multi_channel():
"""Conv with multiple input and output channels."""
# Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3])
W = numpy_helper.from_array(
np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/multi_channel", "conv_multi_channel.onnx")
def conv_1x1():
"""1x1 pointwise convolution (channel mixing)."""
# Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx")
def conv_same_padding_3x3():
"""Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions."""
# Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5])
W = numpy_helper.from_array(
np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx")
def conv_explicit_padding():
"""Conv 3x3 with explicit asymmetric padding."""
# Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4]
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4])
W = numpy_helper.from_array(
np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1])
graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx")
def conv_with_bias_3x3():
"""Conv 3x3 with bias."""
# Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3]
rng = np.random.default_rng(56)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx")
def conv_batch_2():
"""Batched conv (batch=2) with SAME_UPPER padding and bias."""
# Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3]
rng = np.random.default_rng(57)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3])
W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W")
B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B")
node = helper.make_node("Conv", ["X", "W", "B"], ["Y"],
kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER")
graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/batch_2", "conv_batch_2.onnx")
def conv_large_spatial():
"""Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6])
W = numpy_helper.from_array(
np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W")
node = helper.make_node("Conv", ["X", "W"], ["Y"],
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "conv/large_spatial", "conv_large_spatial.onnx")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("Generating GEMM tests:")
gemm_non_square()
gemm_with_bias()
gemm_transB()
gemm_alpha_beta()
gemm_small()
gemm_large()
gemm_transB_with_bias()
print("\nGenerating Conv tests:")
conv_3x3_kernel()
conv_stride2()
conv_multi_channel()
conv_1x1()
conv_same_padding_3x3()
conv_explicit_padding()
conv_with_bias_3x3()
conv_batch_2()
conv_large_spatial()
print("\nDone.")