diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 85beb04..ea95ce8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -5,9 +5,11 @@ add_public_tablegen_target(ONNXToSpatialIncGen) add_onnx_mlir_library(OMONNXToSpatial Math/Gemm.cpp Math/Conv.cpp + Math/MatMul.cpp NN/Pooling.cpp NN/ReduceMean.cpp Tensor/ONNXConcatToTensorConcat.cpp + Tensor/ONNXReshapeToTensorReshape.cpp Tensor/RemoveUnusedHelperOps.cpp Utils/SpatialReducer.cpp Utils/WeightSubdivider.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index a9d4dfb..981510b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -130,19 +130,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, }); Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); - // Reshape bias [numChannelsOut] -> [1, numChannelsOut] for Gemm C row-broadcasting, or use none + // Pass bias through directly; Gemm handles rank-1 C canonicalization. bool hasB = !isa(b.getDefiningOp()); Value gemmC; - if (hasB) { - auto biasType = RankedTensorType::get({1, numChannelsOut}, cast(b.getType()).getElementType()); - gemmC = tensor::ExpandShapeOp::create(rewriter, - loc, - biasType, - b, - SmallVector { - {0, 1} - }); - } + if (hasB) + gemmC = b; else gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp index 53bde28..0ca061c 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Gemm.cpp @@ -23,6 +23,38 @@ namespace { constexpr StringRef COMPUTE_HAS_SOFTMAX_DIVISOR_ATTRNAME = "computeWithSoftmaxDivisor"; +static FailureOr materializeScaledConstantTensor(Value value, + float factor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (factor == 1.0f) + return value; + + auto constantOp = value.getDefiningOp(); + if (!constantOp) + return failure(); + + auto denseAttr = dyn_cast(constantOp.getValue()); + if (!denseAttr) + return failure(); + + SmallVector scaledValues; + scaledValues.reserve(denseAttr.getNumElements()); + APFloat scale(factor); + bool hadFailure = false; + for (const APFloat& originalValue : denseAttr.getValues()) { + APFloat scaledValue(originalValue); + if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven)) + hadFailure = true; + scaledValues.push_back(std::move(scaledValue)); + } + if (hadFailure) + return failure(); + + auto scaledAttr = DenseFPElementsAttr::get(cast(denseAttr.getType()), scaledValues); + return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); +} + struct GemmToManyGemv : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -74,10 +106,25 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, if (numOutRows <= 1) return failure(); + auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc); + if (failed(scaledB)) + return failure(); + b = *scaledB; + RankedTensorType cType = nullptr; bool cHasNumOutRows = false; if (hasC) { + auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); + if (failed(scaledC)) + return failure(); + c = *scaledC; cType = cast(c.getType()); + // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling + if (cType.getRank() == 1) { + auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); + c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c, SmallVector{{0, 1}}); + cType = expandedType; + } assert("Only support rank 2 tensor for C" && cType.getRank() == 2); cHasNumOutRows = cType.getDimSize(0) == numOutRows; } @@ -112,8 +159,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, aSlice, b, cSlice, - gemmOp.getAlphaAttr(), - gemmOp.getBetaAttr(), + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), gemmOp.getTransAAttr(), gemmOp.getTransBAttr()); gemvOps.push_back(gemvOp.getY()); @@ -158,6 +205,12 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, bool hasC = !isa(c.getDefiningOp()); if (hasC) { cType = cast(c.getType()); + // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling + if (cType.getRank() == 1) { + auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); + c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c, SmallVector{{0, 1}}); + cType = expandedType; + } assert("Only support rank 2 tensor for C" && cType.getRank() == 2); } @@ -177,19 +230,24 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, auto bShape = bType.getShape(); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + bType = cast(b.getType()); } if (alpha != 1.0f) { - auto alphaTensorType = RankedTensorType::get({1, 1}, cast(a.getType()).getElementType()); - auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); - auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue); - a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor); + auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc); + if (failed(scaledB)) + return failure(); + b = *scaledB; + bType = cast(b.getType()); + alpha = 1.0f; } if (hasC && beta != 1.0f) { - auto betaTensorType = RankedTensorType::get({1, 1}, cast(c.getType()).getElementType()); - auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); - auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue); - c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor); + auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc); + if (failed(scaledC)) + return failure(); + c = *scaledC; + cType = cast(c.getType()); + beta = 1.0f; } auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp new file mode 100644 index 0000000..a37cd74 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Math/MatMul.cpp @@ -0,0 +1,108 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct MatMulRank3ToGemm : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { + auto lhsType = dyn_cast(matmulOp.getA().getType()); + auto rhsType = dyn_cast(matmulOp.getB().getType()); + auto outType = dyn_cast(matmulOp.getY().getType()); + if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() + || !outType.hasStaticShape()) + return failure(); + if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3) + return failure(); + + const int64_t batch = rhsType.getDimSize(0); + const int64_t k = rhsType.getDimSize(1); + const int64_t n = rhsType.getDimSize(2); + const int64_t m = lhsType.getDimSize(0); + if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m + || outType.getDimSize(2) != n) + return failure(); + + Location loc = matmulOp.getLoc(); + auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType()); + auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType()); + auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType()); + auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType()); + auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType()); + auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType()); + + Value lhsTransposed = + ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0})); + Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + + SmallVector gemmRows; + gemmRows.reserve(batch * n); + for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) { + for (int64_t colIdx = 0; colIdx < n; colIdx++) { + SmallVector offsets = { + rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)}; + SmallVector sizes = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)}; + SmallVector strides = { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value rhsSlice = + tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides); + Value rhsRow = tensor::CollapseShapeOp::create( + rewriter, loc, rhsRowType, rhsSlice, SmallVector{{0}, {1, 2}}); + + auto gemmOp = ONNXGemmOp::create(rewriter, + loc, + gemmRowType, + rhsRow, + lhsTransposed, + none, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + gemmRows.push_back(gemmOp.getY()); + } + } + + auto concatComputeOp = + spatial::SpatWeightedCompute::create(rewriter, loc, gemmOutType, SmallVector(), gemmRows); + + auto* concatBlock = new Block(); + for (Value gemmRow : gemmRows) + concatBlock->addArgument(gemmRow.getType(), loc); + concatComputeOp.getBody().push_back(concatBlock); + rewriter.setInsertionPointToStart(concatBlock); + + auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, concatBlock->getArguments()); + spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); + + rewriter.setInsertionPointAfter(concatComputeOp); + Value gemmOut = concatComputeOp.getResult(0); + Value gemmExpanded = tensor::ExpandShapeOp::create( + rewriter, loc, gemmExpandedType, gemmOut, SmallVector{{0, 1}, {2}}); + Value result = ONNXTransposeOp::create( + rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1})); + + rewriter.replaceOp(matmulOp, result); + return success(); + } +}; + +} // namespace + +void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td index f83d758..cb3401b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.td @@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat< // ONNXMatMulOp to ONNXGemmOp patterns +def IsRank2Result: Constraint< + CPred<"cast($0.getType()).getRank() == 2">, + "Result is rank 2">; + def matMulAddToGemmPattern : Pat< (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C), (ONNXGemmOp $A, $B, $C, @@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat< /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) - ) + ), + [(IsRank2Result $matmulres)] >; def matMulToGemmPattern : Pat< (ONNXMatMulOp:$matmulres $A, $B), ( - ONNXGemmOp $A, $B, + ONNXGemmOp $A, $B, /* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast(matmulres.getY().getType()).getShape(), cast(matmulres.getY().getType()).getElementType());">), /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">) - ) + ), + [(IsRank2Result $matmulres)] >; // ONNXConvOp + ONNXAddOp to ONNXConvOp pattern diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 0bd82d0..02f92c8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,6 +1,5 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -56,6 +55,7 @@ void ONNXToSpatialPass::runOnOperation() { mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); + populateMatMulRewritePatterns(mergeActivationPatterns, ctx); if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; @@ -74,7 +74,9 @@ void ONNXToSpatialPass::runOnOperation() { ConversionTarget target(*ctx); target.addLegalDialect(); - target.addIllegalOp(); + target.addDynamicallyLegalOp([](ONNXMatMulOp op) { + return cast(op.getY().getType()).getRank() != 2; + }); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -83,6 +85,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); RewritePatternSet patterns(ctx); patterns.add(ctx); @@ -90,6 +93,7 @@ void ONNXToSpatialPass::runOnOperation() { populateConvOpPatterns(patterns, ctx); populatePoolingTilingPattern(patterns, ctx); populateOnnxGemmOpPatterns(patterns, ctx); + populateReshapeConversionPattern(patterns, ctx); populateONNXConcatToTensorConcatPattern(patterns, ctx); populateReduceMeanConversionPattern(patterns, ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp index 1a1a3f6..1b7ffaa 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -7,12 +7,16 @@ namespace onnx_mlir { void populateConvOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateReshapeConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp new file mode 100644 index 0000000..629b0ce --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXReshapeToTensorReshape.cpp @@ -0,0 +1,121 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static bool haveStaticPositiveShape(ArrayRef shape) { + return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); +} + +static bool inferCollapseReassociation(ArrayRef sourceShape, + ArrayRef resultShape, + SmallVector& reassociation) { + reassociation.clear(); + + size_t sourceIdx = 0; + size_t resultIdx = 0; + while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) { + int64_t sourceProduct = sourceShape[sourceIdx]; + int64_t resultProduct = resultShape[resultIdx]; + + ReassociationIndices group; + group.push_back(sourceIdx); + while (sourceProduct != resultProduct) { + if (sourceProduct > resultProduct) + return false; + sourceIdx++; + if (sourceIdx >= sourceShape.size()) + return false; + group.push_back(sourceIdx); + sourceProduct *= sourceShape[sourceIdx]; + } + + reassociation.push_back(group); + sourceIdx++; + resultIdx++; + } + + return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); +} + +static bool inferExpandReassociation(ArrayRef sourceShape, + ArrayRef resultShape, + SmallVector& reassociation) { + reassociation.clear(); + + size_t sourceIdx = 0; + size_t resultIdx = 0; + while (sourceIdx < sourceShape.size() && resultIdx < resultShape.size()) { + int64_t sourceProduct = sourceShape[sourceIdx]; + int64_t resultProduct = resultShape[resultIdx]; + + ReassociationIndices group; + group.push_back(resultIdx); + while (resultProduct != sourceProduct) { + if (resultProduct > sourceProduct) + return false; + resultIdx++; + if (resultIdx >= resultShape.size()) + return false; + group.push_back(resultIdx); + resultProduct *= resultShape[resultIdx]; + } + + reassociation.push_back(group); + sourceIdx++; + resultIdx++; + } + + return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); +} + +struct ONNXReshapeToTensorReshape : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXReshapeOp reshapeOp, + ONNXReshapeOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto sourceType = dyn_cast(adaptor.getData().getType()); + auto resultType = dyn_cast(reshapeOp.getReshaped().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape())) + return failure(); + + if (sourceType == resultType) { + rewriter.replaceOp(reshapeOp, adaptor.getData()); + return success(); + } + + SmallVector reassociation; + if (sourceType.getRank() > resultType.getRank() + && inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { + rewriter.replaceOpWithNewOp(reshapeOp, resultType, adaptor.getData(), reassociation); + return success(); + } + + if (sourceType.getRank() < resultType.getRank() + && inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { + rewriter.replaceOpWithNewOp(reshapeOp, resultType, adaptor.getData(), reassociation); + return success(); + } + + return failure(); + } +}; + +} // namespace + +void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 7490fb2..2ffdb7b 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -79,8 +80,31 @@ private: } // namespace static bool isChannelUseChainOp(Operation* op) { - return isa( - op); + return isa(op); +} + +static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { + for (Value operand : op->getOperands()) { + if (mapping.lookupOrNull(operand)) + continue; + + Operation* definingOp = operand.getDefiningOp(); + if (!definingOp) + continue; + + if (!isa(definingOp)) + continue; + + Operation* clonedOp = rewriter.clone(*definingOp, mapping); + for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + rewriter.setInsertionPointAfter(clonedOp); + } } static size_t countComputeLeafUsers(Value value) { @@ -204,6 +228,56 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR OpOperand& resultUse = *resultUses.begin(); Operation* resultUser = resultUse.getOwner(); + if (isChannelUseChainOp(resultUser)) { + SmallVector returnChain; + Value chainedValue = result; + Operation* chainUser = resultUser; + + while (isChannelUseChainOp(chainUser)) { + returnChain.push_back(chainUser); + auto chainUses = chainUser->getResult(0).getUses(); + if (rangeLength(chainUses) != 1) + break; + chainedValue = chainUser->getResult(0); + chainUser = chainUses.begin()->getOwner(); + } + + if (isa(chainUser)) { + size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber(); + + rewriter.setInsertionPoint(yieldOp); + IRMapping mapping; + mapping.map(result, yieldValue); + + Value storedValue = yieldValue; + for (Operation* op : returnChain) { + cloneMappedHelperOperands(op, mapping, rewriter); + Operation* clonedOp = rewriter.clone(*op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + storedValue = clonedOp->getResult(0); + rewriter.setInsertionPointAfter(clonedOp); + markOpToRemove(op); + } + + auto storedType = cast(storedValue.getType()); + size_t elementSize = storedType.getElementTypeBitWidth() / 8; + + Value outputTensor = outputTensors[resultIndexInReturn]; + if (auto storedOp = storedValue.getDefiningOp()) + rewriter.setInsertionPointAfter(storedOp); + PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + storedValue, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize)); + continue; + } + } + if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); size_t offset = 0; @@ -493,6 +567,7 @@ void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu IRMapping mapping; mapping.map(channelSourceOp, receivedValue); for (Operation* op : llvm::reverse(clonedChain)) { + cloneMappedHelperOperands(op, mapping, rewriter); Operation* clonedOp = rewriter.clone(*op, mapping); for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) mapping.map(originalResult, newResult); diff --git a/src/PIM/Pass/PimConstantFoldingPass.cpp b/src/PIM/Pass/PimConstantFoldingPass.cpp index 9fe6a2f..bf70814 100644 --- a/src/PIM/Pass/PimConstantFoldingPass.cpp +++ b/src/PIM/Pass/PimConstantFoldingPass.cpp @@ -30,6 +30,24 @@ static Value stripMemRefCasts(Value value) { return value; } +static Value stripMemRefViewOps(Value value) { + while (true) { + if (auto castOp = value.getDefiningOp()) { + value = castOp.getSource(); + continue; + } + if (auto collapseOp = value.getDefiningOp()) { + value = collapseOp.getSrc(); + continue; + } + if (auto expandOp = value.getDefiningOp()) { + value = expandOp.getSrc(); + continue; + } + return value; + } +} + static memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, Location loc, MemRefType globalType, @@ -204,6 +222,7 @@ struct StaticSubviewInfo { }; static FailureOr getStaticSubviewInfo(Value value) { + value = stripMemRefViewOps(value); auto subviewOp = value.getDefiningOp(); if (!subviewOp) return failure(); @@ -321,6 +340,77 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern } }; +struct RewriteHostSubviewLoadPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyHostToDevOp copyOp, PatternRewriter& rewriter) const override { + auto srcSubview = getStaticSubviewInfo(copyOp.getHostSrc()); + auto dstSubview = getStaticSubviewInfo(copyOp.getDeviceDst()); + const bool splitSrc = succeeded(srcSubview) + && !isMemoryContiguous(srcSubview->sourceShape, srcSubview->offsets, srcSubview->sizes, srcSubview->strides); + const bool splitDst = succeeded(dstSubview) + && !isMemoryContiguous(dstSubview->sourceShape, dstSubview->offsets, dstSubview->sizes, dstSubview->strides); + if (!splitSrc && !splitDst) + return failure(); + + auto sourceType = dyn_cast(copyOp.getHostSrc().getType()); + auto dstType = dyn_cast(copyOp.getDeviceDst().getType()); + if (!sourceType || !dstType || !sourceType.hasStaticShape() || !dstType.hasStaticShape()) + return failure(); + if (sourceType.getElementType() != dstType.getElementType()) + return failure(); + + if (splitSrc && llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + if (splitDst && llvm::any_of(dstSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + + ArrayRef copyShape = splitSrc ? ArrayRef(srcSubview->sizes) : ArrayRef(dstSubview->sizes); + if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) + return failure(); + + const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; + if (elementByteWidth <= 0) + return failure(); + + const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; + if (copyOp.getSize() != totalBytes) + return failure(); + + const int64_t sliceBytes = copyShape.back() * elementByteWidth; + if (sliceBytes <= 0) + return failure(); + + SmallVector outerShape(copyShape.begin(), copyShape.end() - 1); + auto outerStrides = computeRowMajorStrides(outerShape); + const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape); + + rewriter.setInsertionPoint(copyOp); + for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) { + SmallVector outerIndices = + outerShape.empty() ? SmallVector{} : delinearizeIndex(linearIndex, outerShape, outerStrides); + const int64_t srcByteOffset = copyOp.getHostSrcOffset() + + (splitSrc ? getSubviewChunkOffsetBytes(*srcSubview, outerIndices, elementByteWidth) + : linearIndex * sliceBytes); + const int64_t dstByteOffset = copyOp.getDeviceDstOffset() + + (splitDst ? getSubviewChunkOffsetBytes(*dstSubview, outerIndices, elementByteWidth) + : linearIndex * sliceBytes); + pim::PimMemCopyHostToDevOp::create( + rewriter, + copyOp.getLoc(), + splitDst ? cast(dstSubview->source.getType()) : dstType, + splitDst ? dstSubview->source : copyOp.getDeviceDst(), + splitSrc ? srcSubview->source : copyOp.getHostSrc(), + rewriter.getI32IntegerAttr(static_cast(dstByteOffset)), + rewriter.getI32IntegerAttr(static_cast(srcByteOffset)), + rewriter.getI32IntegerAttr(static_cast(sliceBytes))); + } + + rewriter.replaceOp(copyOp, copyOp.getDeviceDst()); + return success(); + } +}; + static FailureOr foldConstantAlloc(memref::AllocOp allocOp, ModuleOp moduleOp) { auto allocType = dyn_cast(allocOp.getType()); if (!allocType || !allocType.hasStaticShape()) @@ -578,6 +668,170 @@ struct FoldConstantAllocPattern final : OpRewritePattern { } }; +struct FoldConstantMemCpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(pim::PimMemCopyOp copyOp, PatternRewriter& rewriter) const override { + // Only match top-level memcp (not inside pim.core) + if (copyOp->getParentOfType()) + return failure(); + + // dst must be an alloc with static shape + auto allocOp = copyOp.getDst().getDefiningOp(); + if (!allocOp) + return failure(); + auto allocType = dyn_cast(allocOp.getType()); + if (!allocType || !allocType.hasStaticShape()) + return failure(); + + // The copy must cover the full destination (offsets both zero) + if (copyOp.getDstOffset() != 0 || copyOp.getSrcOffset() != 0) + return failure(); + + // Resolve the source through an optional subview to a get_global + auto srcSubview = getStaticSubviewInfo(copyOp.getSrc()); + Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSrc()); + + auto moduleOp = copyOp->getParentOfType(); + if (!moduleOp) + return failure(); + + auto denseAttr = getDenseGlobalValue(moduleOp, globalSource); + if (failed(denseAttr)) + return failure(); + + // Build the folded dense attribute + DenseElementsAttr foldedAttr; + if (succeeded(srcSubview)) { + // Extract the sub-tensor from the source constant + auto sourceType = dyn_cast(denseAttr->getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return failure(); + if (llvm::any_of(srcSubview->strides, [](int64_t s) { return s != 1; })) + return failure(); + + auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); + const int64_t numResultElements = resultTensorType.getNumElements(); + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); + SmallVector sourceValues(denseAttr->getValues()); + SmallVector resultValues(numResultElements); + + for (int64_t i = 0; i < numResultElements; ++i) { + auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); + SmallVector sourceIndices; + sourceIndices.reserve(resultIndices.size()); + for (auto [off, idx] : llvm::zip_equal(srcSubview->offsets, resultIndices)) + sourceIndices.push_back(off + idx); + int64_t srcLinear = linearizeIndex(sourceIndices, sourceStrides); + resultValues[i] = sourceValues[srcLinear]; + } + foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + } + else { + // Direct copy from a global — just reuse its dense attribute + auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); + if (resultTensorType != denseAttr->getType()) + return failure(); + foldedAttr = *denseAttr; + } + + // Verify that the alloc's remaining users are supported ops. + bool allLiveUsersAreCores = true; + for (Operation* user : allocOp->getUsers()) { + if (user == copyOp) + continue; + if (isa(user)) + continue; + if (isa(user)) + continue; + if (isa(user)) { + allLiveUsersAreCores = false; + continue; + } + return failure(); + } + + auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp"); + if (allLiveUsersAreCores) + markWeightAlways(newGlobal); + + rewriter.setInsertionPoint(allocOp); + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, allocOp.getLoc(), allocType, newGlobal.getName()); + if (allLiveUsersAreCores) + markWeightAlways(newGetGlobal); + + rewriter.replaceAllUsesWith(allocOp.getResult(), newGetGlobal.getResult()); + rewriter.eraseOp(copyOp); + if (allocOp.use_empty()) + rewriter.eraseOp(allocOp); + return success(); + } +}; + +struct FoldConstantCoreSubviewPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::SubViewOp subviewOp, PatternRewriter& rewriter) const override { + // Only handle subviews whose users are all pim.core ops. + if (subviewOp.use_empty()) + return failure(); + if (!llvm::all_of(subviewOp->getUsers(), [](Operation* user) { return isa(user); })) + return failure(); + + // Source must resolve to a constant get_global. + auto moduleOp = subviewOp->getParentOfType(); + if (!moduleOp) + return failure(); + auto denseAttr = getDenseGlobalValue(moduleOp, stripMemRefCasts(subviewOp.getSource())); + if (failed(denseAttr)) + return failure(); + + // Static subview info. + auto subviewInfo = getStaticSubviewInfo(subviewOp.getResult()); + if (failed(subviewInfo)) + return failure(); + if (llvm::any_of(subviewInfo->strides, [](int64_t s) { return s != 1; })) + return failure(); + + auto sourceType = dyn_cast(denseAttr->getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return failure(); + + // Build the contiguous result type. + auto elementType = cast(subviewOp.getType()).getElementType(); + auto resultMemRefType = MemRefType::get( + SmallVector(subviewInfo->sizes.begin(), subviewInfo->sizes.end()), elementType); + auto resultTensorType = RankedTensorType::get(resultMemRefType.getShape(), elementType); + const int64_t numResultElements = resultTensorType.getNumElements(); + + // Extract the sub-tensor. + auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); + auto resultStrides = computeRowMajorStrides(resultTensorType.getShape()); + SmallVector sourceValues(denseAttr->getValues()); + SmallVector resultValues(numResultElements); + for (int64_t i = 0; i < numResultElements; ++i) { + auto resultIndices = delinearizeIndex(i, resultTensorType.getShape(), resultStrides); + SmallVector sourceIndices; + sourceIndices.reserve(resultIndices.size()); + for (auto [off, idx] : llvm::zip_equal(subviewInfo->offsets, resultIndices)) + sourceIndices.push_back(off + idx); + resultValues[i] = sourceValues[linearizeIndex(sourceIndices, sourceStrides)]; + } + auto foldedAttr = DenseElementsAttr::get(resultTensorType, resultValues); + + auto newGlobal = createFoldedGlobal(moduleOp, subviewOp.getLoc(), resultMemRefType, foldedAttr, "pim_folded_subview"); + markWeightAlways(newGlobal); + + rewriter.setInsertionPoint(subviewOp); + auto newGetGlobal = memref::GetGlobalOp::create(rewriter, subviewOp.getLoc(), resultMemRefType, newGlobal.getName()); + markWeightAlways(newGetGlobal); + + rewriter.replaceOp(subviewOp, newGetGlobal.getResult()); + return success(); + } +}; + struct PimConstantFoldingPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass) @@ -591,7 +845,13 @@ struct PimConstantFoldingPass : PassWrappergetRegisteredOperations()) op.getCanonicalizationPatterns(owningPatterns, context); owningPatterns - .add( + .add( context); patterns = std::make_shared(std::move(owningPatterns)); return success(); diff --git a/validation/operations/README.md b/validation/operations/README.md new file mode 100644 index 0000000..03c2a5c --- /dev/null +++ b/validation/operations/README.md @@ -0,0 +1,47 @@ +# Validation Operations + +ONNX test models used by `validate.py` to verify the Raptor compiler + PIM simulator pipeline. + +Generated tests can be regenerated with: +``` +python3 validation/operations/gen_tests.py +``` + +## Conv + +| Test | Directory | Input | Output | Kernel | Stride | Padding | Bias | Notes | +|------|-----------|-------|--------|--------|--------|---------|------|-------| +| Simple | `conv/simple` | [1,3,3,3] | [1,1,2,2] | 2x2 | 1 | none | no | Basic conv, hand-crafted | +| With constant | `conv/with_constant` | [1,3,3,3] | [1,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Hand-crafted, constant weight+bias | +| Batch 2 | `conv/batch_2` | [2,3,3,3] | [2,1,3,3] | 2x2 | 1 | SAME_UPPER | yes | Batched input | +| Kernel 3x3 | `conv/kernel_3x3` | [1,1,5,5] | [1,1,3,3] | 3x3 | 1 | none | no | Larger kernel | +| Stride 2 | `conv/stride_2` | [1,1,6,6] | [1,1,2,2] | 3x3 | 2 | none | no | Strided convolution | +| Multi channel | `conv/multi_channel` | [1,3,5,5] | [1,4,3,3] | 3x3 | 1 | none | no | 3 in channels, 4 out channels | +| Pointwise 1x1 | `conv/pointwise_1x1` | [1,8,4,4] | [1,4,4,4] | 1x1 | 1 | none | no | Channel mixing | +| SAME padding 3x3 | `conv/same_padding_3x3` | [1,1,5,5] | [1,1,5,5] | 3x3 | 1 | SAME_UPPER | no | Spatial dims preserved | +| Explicit padding | `conv/explicit_padding` | [1,1,4,4] | [1,1,4,4] | 3x3 | 1 | [1,1,1,1] | no | Symmetric explicit pads | +| With bias 3x3 | `conv/with_bias_3x3` | [1,3,5,5] | [1,2,3,3] | 3x3 | 1 | none | yes | Multi-channel with bias | +| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input | + +## Gemm + +| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | +|------|-----------|-----------|------------|--------|--------|-------|------|------|-------| +| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights | +| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N | +| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector | +| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight | +| Alpha/beta | `gemm/alpha_beta` | [4,64] | [64,64] | [4,64] | no | 0.5 | 0.25 | [64] | Scaled matmul + bias | +| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices | +| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices | +| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined | + +## Gemv + +| Test | Directory | Input | W (weight) | Output | Bias | Notes | +|------|-----------|-------|------------|--------|------|-------| +| Simple | `gemv/simple` | [1,132] | [132,132] | [1,132] | no | Single-sample matmul | +| Constant | `gemv/constant` | _(none)_ | [132,132] | [1,132] | no | All inputs constant | +| Homogeneous const | `gemv/with_homogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Bias matches output shape | +| Heterogeneous const | `gemv/with_heterogeneous_constant` | [1,132] | [132,132] | [1,132] | [1,132] | Different constant pattern | +| Scalar const | `gemv/with_scalar_constant` | [1,132] | [132,132] | [1,132] | [1,1] | Scalar bias, broadcast | diff --git a/validation/operations/conv/batch_2/conv_batch_2.onnx b/validation/operations/conv/batch_2/conv_batch_2.onnx new file mode 100644 index 0000000..8fa369c Binary files /dev/null and b/validation/operations/conv/batch_2/conv_batch_2.onnx differ diff --git a/validation/operations/conv/batch_64/conv_batch_64.onnx b/validation/operations/conv/batch_64/conv_batch_64.onnx deleted file mode 100644 index d5775f7..0000000 Binary files a/validation/operations/conv/batch_64/conv_batch_64.onnx and /dev/null differ diff --git a/validation/operations/conv/explicit_padding/conv_explicit_padding.onnx b/validation/operations/conv/explicit_padding/conv_explicit_padding.onnx new file mode 100644 index 0000000..15813f7 Binary files /dev/null and b/validation/operations/conv/explicit_padding/conv_explicit_padding.onnx differ diff --git a/validation/operations/conv/kernel_3x3/conv_kernel_3x3.onnx b/validation/operations/conv/kernel_3x3/conv_kernel_3x3.onnx new file mode 100644 index 0000000..701fc94 Binary files /dev/null and b/validation/operations/conv/kernel_3x3/conv_kernel_3x3.onnx differ diff --git a/validation/operations/conv/large_spatial/conv_large_spatial.onnx b/validation/operations/conv/large_spatial/conv_large_spatial.onnx new file mode 100644 index 0000000..3e3d862 Binary files /dev/null and b/validation/operations/conv/large_spatial/conv_large_spatial.onnx differ diff --git a/validation/operations/conv/multi_channel/conv_multi_channel.onnx b/validation/operations/conv/multi_channel/conv_multi_channel.onnx new file mode 100644 index 0000000..bfec7b2 Binary files /dev/null and b/validation/operations/conv/multi_channel/conv_multi_channel.onnx differ diff --git a/validation/operations/conv/pointwise_1x1/conv_1x1.onnx b/validation/operations/conv/pointwise_1x1/conv_1x1.onnx new file mode 100644 index 0000000..50cf143 Binary files /dev/null and b/validation/operations/conv/pointwise_1x1/conv_1x1.onnx differ diff --git a/validation/operations/conv/same_padding_3x3/conv_same_padding_3x3.onnx b/validation/operations/conv/same_padding_3x3/conv_same_padding_3x3.onnx new file mode 100644 index 0000000..3a017fc Binary files /dev/null and b/validation/operations/conv/same_padding_3x3/conv_same_padding_3x3.onnx differ diff --git a/validation/operations/conv/stride_2/conv_stride_2.onnx b/validation/operations/conv/stride_2/conv_stride_2.onnx new file mode 100644 index 0000000..9135966 Binary files /dev/null and b/validation/operations/conv/stride_2/conv_stride_2.onnx differ diff --git a/validation/operations/conv/with_bias_3x3/conv_with_bias_3x3.onnx b/validation/operations/conv/with_bias_3x3/conv_with_bias_3x3.onnx new file mode 100644 index 0000000..d4f81e4 Binary files /dev/null and b/validation/operations/conv/with_bias_3x3/conv_with_bias_3x3.onnx differ diff --git a/validation/operations/gemm/alpha_beta/gemm_alpha_beta.onnx b/validation/operations/gemm/alpha_beta/gemm_alpha_beta.onnx new file mode 100644 index 0000000..1248276 Binary files /dev/null and b/validation/operations/gemm/alpha_beta/gemm_alpha_beta.onnx differ diff --git a/validation/operations/gemm/large/gemm_large.onnx b/validation/operations/gemm/large/gemm_large.onnx new file mode 100644 index 0000000..d44a8db Binary files /dev/null and b/validation/operations/gemm/large/gemm_large.onnx differ diff --git a/validation/operations/gemm/non_square/gemm_non_square.onnx b/validation/operations/gemm/non_square/gemm_non_square.onnx new file mode 100644 index 0000000..d26ecf4 Binary files /dev/null and b/validation/operations/gemm/non_square/gemm_non_square.onnx differ diff --git a/validation/operations/gemm/small/gemm_small.onnx b/validation/operations/gemm/small/gemm_small.onnx new file mode 100644 index 0000000..49a7f91 Binary files /dev/null and b/validation/operations/gemm/small/gemm_small.onnx differ diff --git a/validation/operations/gemm/transB/gemm_transB.onnx b/validation/operations/gemm/transB/gemm_transB.onnx new file mode 100644 index 0000000..2df2d83 Binary files /dev/null and b/validation/operations/gemm/transB/gemm_transB.onnx differ diff --git a/validation/operations/gemm/transB_with_bias/gemm_transB_with_bias.onnx b/validation/operations/gemm/transB_with_bias/gemm_transB_with_bias.onnx new file mode 100644 index 0000000..696ef58 Binary files /dev/null and b/validation/operations/gemm/transB_with_bias/gemm_transB_with_bias.onnx differ diff --git a/validation/operations/gemm/with_bias/gemm_with_bias.onnx b/validation/operations/gemm/with_bias/gemm_with_bias.onnx new file mode 100644 index 0000000..b1b4bb2 Binary files /dev/null and b/validation/operations/gemm/with_bias/gemm_with_bias.onnx differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py new file mode 100644 index 0000000..812edec --- /dev/null +++ b/validation/operations/gen_tests.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +"""Generate ONNX test models for validating GEMM and Conv implementations.""" + +import numpy as np +import onnx +from onnx import helper, TensorProto, numpy_helper +from pathlib import Path + +OPERATIONS_DIR = Path(__file__).parent + + +def save_model(model, directory, filename): + """Save an ONNX model, creating the directory if needed.""" + d = OPERATIONS_DIR / directory + d.mkdir(parents=True, exist_ok=True) + path = d / filename + onnx.checker.check_model(model) + onnx.save(model, str(path)) + print(f" {path.relative_to(OPERATIONS_DIR)}") + + +# --------------------------------------------------------------------------- +# GEMM tests +# --------------------------------------------------------------------------- + +def gemm_non_square(): + """GEMM with non-square weight matrix: [B, K] @ [K, N], K != N.""" + B, K, N = 4, 128, 64 + W = numpy_helper.from_array(np.random.default_rng(42).uniform(-1, 1, (K, N)).astype(np.float32), name="W") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W"], ["Y"]) + graph = helper.make_graph([node], "gemm_non_square", [A], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/non_square", "gemm_non_square.onnx") + + +def gemm_with_bias(): + """GEMM with bias: Y = A @ W + C.""" + B, K, N = 4, 128, 128 + rng = np.random.default_rng(43) + W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W") + C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"]) + graph = helper.make_graph([node], "gemm_with_bias", [A], [Y], initializer=[W, C]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/with_bias", "gemm_with_bias.onnx") + + +def gemm_transB(): + """GEMM with transB=1: Y = A @ W^T.""" + B, K, N = 4, 128, 64 + rng = np.random.default_rng(44) + # W stored as [N, K], transposed during computation + W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W"], ["Y"], transB=1) + graph = helper.make_graph([node], "gemm_transB", [A], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/transB", "gemm_transB.onnx") + + +def gemm_alpha_beta(): + """GEMM with alpha and beta: Y = 0.5 * A @ W + 0.25 * C.""" + B, K, N = 4, 64, 64 + rng = np.random.default_rng(45) + W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W") + C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], alpha=0.5, beta=0.25) + graph = helper.make_graph([node], "gemm_alpha_beta", [A], [Y], initializer=[W, C]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/alpha_beta", "gemm_alpha_beta.onnx") + + +def gemm_small(): + """Small GEMM: [2, 8] @ [8, 4].""" + B, K, N = 2, 8, 4 + rng = np.random.default_rng(46) + W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W"], ["Y"]) + graph = helper.make_graph([node], "gemm_small", [A], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/small", "gemm_small.onnx") + + +def gemm_large(): + """Larger GEMM: [8, 256] @ [256, 128].""" + B, K, N = 8, 256, 128 + rng = np.random.default_rng(47) + W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W"], ["Y"]) + graph = helper.make_graph([node], "gemm_large", [A], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/large", "gemm_large.onnx") + + +def gemm_transB_with_bias(): + """GEMM with transB and bias: Y = A @ W^T + C.""" + B, K, N = 4, 128, 64 + rng = np.random.default_rng(48) + W = numpy_helper.from_array(rng.uniform(-1, 1, (N, K)).astype(np.float32), name="W") + C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + node = helper.make_node("Gemm", ["A", "W", "C"], ["Y"], transB=1) + graph = helper.make_graph([node], "gemm_transB_with_bias", [A], [Y], initializer=[W, C]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx") + + +# --------------------------------------------------------------------------- +# Conv tests +# --------------------------------------------------------------------------- + +def conv_3x3_kernel(): + """Conv with 3x3 kernel, no padding.""" + # Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3] -> Output: [1, 1, 3, 3] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3]) + W = numpy_helper.from_array( + np.random.default_rng(50).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_3x3", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/kernel_3x3", "conv_kernel_3x3.onnx") + + +def conv_stride2(): + """Conv with 3x3 kernel and stride 2.""" + # Input: [1, 1, 6, 6], Kernel: [1, 1, 3, 3], stride 2 -> Output: [1, 1, 2, 2] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 6, 6]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2]) + W = numpy_helper.from_array( + np.random.default_rng(51).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[2, 2], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_stride2", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/stride_2", "conv_stride_2.onnx") + + +def conv_multi_channel(): + """Conv with multiple input and output channels.""" + # Input: [1, 3, 5, 5], Kernel: [4, 3, 3, 3] -> Output: [1, 4, 3, 3] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 3, 3]) + W = numpy_helper.from_array( + np.random.default_rng(52).uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_multi_channel", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/multi_channel", "conv_multi_channel.onnx") + + +def conv_1x1(): + """1x1 pointwise convolution (channel mixing).""" + # Input: [1, 8, 4, 4], Kernel: [4, 8, 1, 1] -> Output: [1, 4, 4, 4] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 8, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4]) + W = numpy_helper.from_array( + np.random.default_rng(53).uniform(-1, 1, (4, 8, 1, 1)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[1, 1], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_1x1", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/pointwise_1x1", "conv_1x1.onnx") + + +def conv_same_padding_3x3(): + """Conv 3x3 with SAME_UPPER padding, preserving spatial dimensions.""" + # Input: [1, 1, 5, 5], Kernel: [1, 1, 3, 3], SAME_UPPER -> Output: [1, 1, 5, 5] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 5, 5]) + W = numpy_helper.from_array( + np.random.default_rng(54).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], auto_pad="SAME_UPPER") + graph = helper.make_graph([node], "conv_same_3x3", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/same_padding_3x3", "conv_same_padding_3x3.onnx") + + +def conv_explicit_padding(): + """Conv 3x3 with explicit asymmetric padding.""" + # Input: [1, 1, 4, 4], Kernel: [1, 1, 3, 3], pads=[1,1,1,1] -> Output: [1, 1, 4, 4] + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4]) + W = numpy_helper.from_array( + np.random.default_rng(55).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[1, 1, 1, 1]) + graph = helper.make_graph([node], "conv_explicit_pad", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/explicit_padding", "conv_explicit_padding.onnx") + + +def conv_with_bias_3x3(): + """Conv 3x3 with bias.""" + # Input: [1, 3, 5, 5], Kernel: [2, 3, 3, 3], Bias: [2] -> Output: [1, 2, 3, 3] + rng = np.random.default_rng(56) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3]) + W = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W") + B = numpy_helper.from_array(rng.uniform(-1, 1, (2,)).astype(np.float32), name="B") + node = helper.make_node("Conv", ["X", "W", "B"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_with_bias_3x3", [X], [Y], initializer=[W, B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/with_bias_3x3", "conv_with_bias_3x3.onnx") + + +def conv_batch_2(): + """Batched conv (batch=2) with SAME_UPPER padding and bias.""" + # Input: [2, 3, 3, 3], Kernel: [1, 3, 2, 2], Bias: [1] -> Output: [2, 1, 3, 3] + rng = np.random.default_rng(57) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 3, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 1, 3, 3]) + W = numpy_helper.from_array(rng.uniform(-1, 1, (1, 3, 2, 2)).astype(np.float32), name="W") + B = numpy_helper.from_array(rng.uniform(-1, 1, (1,)).astype(np.float32), name="B") + node = helper.make_node("Conv", ["X", "W", "B"], ["Y"], + kernel_shape=[2, 2], strides=[1, 1], auto_pad="SAME_UPPER") + graph = helper.make_graph([node], "conv_batch_2", [X], [Y], initializer=[W, B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/batch_2", "conv_batch_2.onnx") + + +def conv_large_spatial(): + """Conv on larger spatial input: [1, 1, 8, 8] with 3x3 kernel.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 8, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 6, 6]) + W = numpy_helper.from_array( + np.random.default_rng(58).uniform(-1, 1, (1, 1, 3, 3)).astype(np.float32), name="W") + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_large_spatial", [X], [Y], initializer=[W]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/large_spatial", "conv_large_spatial.onnx") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + print("Generating GEMM tests:") + gemm_non_square() + gemm_with_bias() + gemm_transB() + gemm_alpha_beta() + gemm_small() + gemm_large() + gemm_transB_with_bias() + + print("\nGenerating Conv tests:") + conv_3x3_kernel() + conv_stride2() + conv_multi_channel() + conv_1x1() + conv_same_padding_3x3() + conv_explicit_padding() + conv_with_bias_3x3() + conv_batch_2() + conv_large_spatial() + + print("\nDone.")