diff --git a/AGENTS.md b/AGENTS.md index 6e4cdbc..f19d8ed 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -3,6 +3,7 @@ - `cmake --build ./build_release` - `cmake --build ./build_debug` - Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache. +- Always tries the release version build first and ask before building with the debug version # Code changes diff --git a/backend-simulators/pim/pim-simulator/src/lib/cpu/mod.rs b/backend-simulators/pim/pim-simulator/src/lib/cpu/mod.rs index 30cdd09..0ab2f4a 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/cpu/mod.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/cpu/mod.rs @@ -1,3 +1,4 @@ +use crate::utility::AddressArg; use std::{collections::HashMap, fmt::Debug}; use anyhow::{Context, Result, ensure}; @@ -9,6 +10,7 @@ use crate::{ pub mod crossbar; + #[derive(Debug, Clone)] pub struct CPU<'a> { cores: Box<[Core<'a>]>, diff --git a/backend-simulators/pim/pim-simulator/src/lib/utility.rs b/backend-simulators/pim/pim-simulator/src/lib/utility.rs index 96e354d..b21a9d5 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/utility.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/utility.rs @@ -1,3 +1,4 @@ +use anyhow::{Result,Context}; use std::{fmt::Debug, mem::transmute}; diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index bcfe306..59357f1 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -1,8 +1,10 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Matchers.h" #include "ConstantUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -79,4 +81,24 @@ Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFo return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder); } +Value createAffineApplyOrFoldedConstant( + RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) { + SmallVector operandConstants; + operandConstants.reserve(operands.size()); + for (Value operand : operands) { + APInt constantValue; + if (!matchPattern(operand, m_ConstantInt(&constantValue))) + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); + operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue())); + } + + SmallVector foldedResults; + if (succeeded(map.constantFold(operandConstants, foldedResults))) { + if (auto constantResult = dyn_cast(foldedResults.front())) + return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter); + } + + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index d5ea918..dc03959 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -1,5 +1,6 @@ #pragma once +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -29,4 +30,10 @@ mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); +mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::AffineMap map, + mlir::ValueRange operands, + mlir::Operation* anchorOp); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 52a1efb..b96541f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -1,8 +1,12 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" #include "llvm/ADT/SmallVector.h" +#include + #include "ShapeTilingUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" @@ -12,6 +16,72 @@ using namespace mlir; namespace onnx_mlir { +static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) { + if (auto attr = dyn_cast(result)) + return arith::ConstantIndexOp::create(rewriter, loc, cast(attr).getInt()).getResult(); + return cast(result); +} + +static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { + APInt lhsConst; + if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero()) + return rhs; + + APInt rhsConst; + if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero()) + return lhs; + + return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult(); +} + +static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) { + APInt factorConst; + if (auto attr = dyn_cast(factor)) + factorConst = cast(attr).getValue(); + else if (!matchPattern(cast(factor), m_ConstantInt(&factorConst))) + return arith::MulIOp::create(rewriter, loc, value, cast(factor)).getResult(); + + if (factorConst.isZero()) + return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + if (factorConst.isOne()) + return value; + + auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult(); + return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); +} + +static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef strides) { + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank()) + return false; + + for (OpFoldResult stride : strides) { + APInt strideValue; + if (auto attr = dyn_cast(stride)) { + if (cast(attr).getInt() != 1) + return false; + continue; + } + if (!matchPattern(cast(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne()) + return false; + } + + auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()), + llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend())); + auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool { + auto [size, dimension] = sizeAndShape; + return size != dimension; + }); + if (firstDifferentSize == sizesAndShape.end()) + return true; + + ++firstDifferentSize; + return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) { + auto [size, _dimension] = sizeAndShape; + return size == 1; + }); +} + SmallVector sliceTensor( const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { ArrayRef shape = getTensorShape(tensorToSlice); @@ -123,4 +193,87 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte return broadcastCompute.getResult(0); } +Value materializeContiguousTensorSlice(Value source, + RankedTensorType resultType, + ArrayRef offsets, + ArrayRef strides, + ConversionPatternRewriter& rewriter, + Location loc) { + assert(resultType.hasStaticShape() && "expected static result type"); + size_t rank = static_cast(resultType.getRank()); + assert(offsets.size() == rank && "expected rank-matching offsets"); + assert(strides.size() == rank && "expected rank-matching strides"); + + SmallVector sizes; + sizes.reserve(resultType.getRank()); + for (int64_t size : resultType.getShape()) + sizes.push_back(rewriter.getIndexAttr(size)); + + if (isContiguousTensorSlice(source, resultType, strides)) + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); + + if (resultType.getRank() == 0) + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); + + Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); + SmallVector zeroIndices(resultType.getRank()); + for (Value& zeroIndex : zeroIndices) + zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + + SmallVector resultIndices; + resultIndices.reserve(resultType.getRank()); + + auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value { + if (dim == resultType.getRank()) { + SmallVector sourceIndices; + sourceIndices.reserve(resultType.getRank()); + for (unsigned idx = 0; idx < resultType.getRank(); ++idx) { + Value offsetValue = getIndexValue(offsets[idx], rewriter, loc); + Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc); + sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc)); + } + + SmallVector sourceOffsets; + SmallVector destinationOffsets; + SmallVector unitSizes; + SmallVector unitStrides; + sourceOffsets.reserve(resultType.getRank()); + destinationOffsets.reserve(resultType.getRank()); + unitSizes.reserve(resultType.getRank()); + unitStrides.reserve(resultType.getRank()); + for (Value index : sourceIndices) + sourceOffsets.push_back(index); + for (Value index : resultIndices) + destinationOffsets.push_back(index); + for (int64_t idx = 0; idx < resultType.getRank(); ++idx) { + unitSizes.push_back(rewriter.getIndexAttr(1)); + unitStrides.push_back(rewriter.getIndexAttr(1)); + } + + auto elementTensorType = + RankedTensorType::get(SmallVector(resultType.getRank(), 1), resultType.getElementType()); + Value elementSlice = + tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides) + .getResult(); + return tensor::InsertSliceOp::create( + rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides) + .getResult(); + } + + Value lower = zeroIndices[dim]; + Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult(); + Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult(); + auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator}); + rewriter.setInsertionPointToStart(loop.getBody()); + resultIndices.push_back(loop.getInductionVar()); + Value updated = self(self, dim + 1, loop.getRegionIterArgs().front()); + resultIndices.pop_back(); + scf::YieldOp::create(rewriter, loc, updated); + rewriter.setInsertionPointAfter(loop); + return loop.getResult(0); + }; + + return buildLoopNest(buildLoopNest, 0, init); +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index ab03d75..908a9e7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -141,4 +141,11 @@ mlir::Value broadcastToVector(mlir::Value scalarToBroadcast, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); +mlir::Value materializeContiguousTensorSlice(mlir::Value source, + mlir::RankedTensorType resultType, + llvm::ArrayRef offsets, + llvm::ArrayRef strides, + mlir::ConversionPatternRewriter& rewriter, + mlir::Location loc); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index c1f9a80..8f3b16a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -111,6 +111,32 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr, return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr); } +static Value createConvWeightMatrix(Value w, + RankedTensorType wFlatType, + RankedTensorType wTransType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto buildWeightMatrix = [&](Value weight) -> Value { + Value wFlat = tensor::CollapseShapeOp::create(rewriter, + loc, + wFlatType, + weight, + SmallVector { + {0}, + {1, 2, 3} + }); + return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult(); + }; + + if (isCompileTimeComputable(w)) + return buildWeightMatrix(w); + + auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) { + spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight)); + }); + return computeOp.getResult(0); +} + static Value buildPackedBias(bool hasBias, Value gemmBias, Value biasMatrix, @@ -395,15 +421,7 @@ static Value lowerSingleConvGroup(Value x, // Prepare weight matrix W for crossbar storage: // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] - Value wFlat = tensor::CollapseShapeOp::create(rewriter, - loc, - wFlatType, - w, - SmallVector { - {0}, - {1, 2, 3} - }); - Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); + Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc); // Pass bias through directly; Gemm handles rank-1 C canonicalization. bool hasB = !isa(b.getDefiningOp()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index ce9973e..bf15091 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -73,38 +73,11 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va return getOrCreateHostIndexConstant(anchorOp, value, rewriter); } -static std::optional getConstantIndexValue(Value value) { - if (auto constantIndex = value.getDefiningOp()) - return constantIndex.value(); - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) - return constantValue.getSExtValue(); - - return std::nullopt; -} - static Value createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); - - SmallVector operandConstants; - operandConstants.reserve(operands.size()); - for (Value operand : operands) { - std::optional constantValue = getConstantIndexValue(operand); - if (!constantValue) - return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); - operandConstants.push_back(rewriter.getIndexAttr(*constantValue)); - } - - SmallVector foldedResults; - if (succeeded(map.constantFold(operandConstants, foldedResults))) { - auto constantResult = dyn_cast(foldedResults.front()); - if (constantResult) - return createIndexConstant(rewriter, constantResult.getInt()); - } - - return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); } static Value @@ -379,6 +352,233 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, return batchOp; } +static Value createDynamicGemmBatchRow( + Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) { + if (numOutCols == 1) + return lane; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); +} + +static Value createDynamicGemmBatchColumn( + Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) { + return modIndexByConstant(lane, numOutCols, rewriter, loc); +} + +static Value +extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { + SmallVector offsets {rewriter.getIndexAttr(0), column}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType()); + Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc); + SmallVector collapseReassociation {ReassociationIndices {0, 1}}; + auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType()); + Value collapsed = + tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult(); + SmallVector expandReassociation {ReassociationIndices {0, 1}}; + return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult(); +} + +static Value extractTransposedBRow( + Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { + SmallVector offsets {row, rewriter.getIndexAttr(0)}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult(); +} + +static Value extractDynamicGemmRowVector( + Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) { + SmallVector offsets {row, rewriter.getIndexAttr(0)}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, offsets, sizes, strides).getResult(); +} + +static FailureOr verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) { + if (!cType.hasStaticShape() || cType.getRank() > 2) + return failure(); + + if (cType.getRank() == 0) + return cType; + + int64_t numOutRows = outType.getDimSize(0); + int64_t numOutCols = outType.getDimSize(1); + if (cType.getRank() == 1) { + int64_t cols = cType.getDimSize(0); + if (cols == 1 || cols == numOutCols) + return cType; + return failure(); + } + + int64_t rows = cType.getDimSize(0); + int64_t cols = cType.getDimSize(1); + if ((rows == 1 || rows == numOutRows) && (cols == 1 || cols == numOutCols)) + return cType; + return failure(); +} + +static bool hasGemmBias(Value c) { + Operation* definingOp = c.getDefiningOp(); + return !definingOp || !isa(definingOp); +} + +static Value createScalarTensorConstant(RankedTensorType scalarType, + float value, + ConversionPatternRewriter& rewriter, + Location loc) { + auto elementType = scalarType.getElementType(); + auto scalarAttr = rewriter.getFloatAttr(elementType, value); + auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr); + return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult(); +} + +static Value createBroadcastedBiasScalar(Value bias, + RankedTensorType biasType, + Value row, + Value column, + RankedTensorType scalarType, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector unitStrides(biasType.getRank(), rewriter.getIndexAttr(1)); + if (biasType.getRank() == 1) { + SmallVector offsets { + biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)}; + SmallVector sizes {rewriter.getIndexAttr(1)}; + auto vectorType = RankedTensorType::get({1}, scalarType.getElementType()); + Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides) + .getResult(); + SmallVector reassociation {ReassociationIndices {0, 1}}; + return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult(); + } + + if (biasType.getRank() == 2) { + SmallVector offsets { + biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(row), + biasType.getDimSize(1) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, unitStrides).getResult(); + } + + Value scalar = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult(); + return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult(); +} + +static spatial::SpatComputeBatch createVvdmulBatch(Value a, + Value b, + RankedTensorType aType, + RankedTensorType bType, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + bool bAlreadyTransposed, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + const int64_t reductionSize = aType.getDimSize(1); + const int64_t laneCount = numOutRows * numOutCols; + auto batchOp = spatial::SpatComputeBatch::create(rewriter, + loc, + TypeRange {scalarPiecesType}, + rewriter.getI32IntegerAttr(static_cast(laneCount)), + ValueRange {}, + ValueRange {a, b}); + + SmallVector blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType}; + SmallVector blockArgLocs(blockArgTypes.size(), loc); + Block* body = + rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToEnd(body); + + auto lane = batchOp.getLaneArgument(); + auto inputA = batchOp.getInputArgument(0); + auto inputB = batchOp.getInputArgument(1); + auto output = batchOp.getOutputArgument(0); + assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body"); + + Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc); + Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc); + + auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); + auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); + Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc); + Value bVector = bAlreadyTransposed + ? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc) + : extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc); + Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); + + auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); + rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + SmallVector outputOffsets {*lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides); + + rewriter.setInsertionPointAfter(batchOp); + return batchOp; +} + +static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, + Value bias, + RankedTensorType scalarPiecesType, + RankedTensorType biasType, + RankedTensorType outType, + float alpha, + float beta, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t laneCount = scalarPiecesType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + SmallVector inputs {scalarPieces}; + if (bias) + inputs.push_back(bias); + + return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { + Value pieces = blockArgs[0]; + Value biasArg = bias ? blockArgs[1] : Value(); + auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); + Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); + Value c0 = createIndexConstant(rewriter, 0); + Value c1 = createIndexConstant(rewriter, 1); + Value cLaneCount = createIndexConstant(rewriter, laneCount); + auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); + rewriter.setInsertionPointToStart(loop.getBody()); + + Value lane = loop.getInductionVar(); + Value outputAcc = loop.getRegionIterArgs().front(); + Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); + Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc); + SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value scalar = + tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides) + .getResult(); + if (alpha != 1.0f) { + Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc); + scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult(); + } + if (biasArg) { + Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc); + if (beta != 1.0f) { + Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc); + biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult(); + } + scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult(); + } + SmallVector outputOffsets {row, column}; + Value outputNext = + tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides) + .getResult(); + scf::YieldOp::create(rewriter, loc, outputNext); + + rewriter.setInsertionPointAfter(loop); + spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0)); + }); +} + static Value createPartialGroupOffset(Value hSlice, int64_t kSlice, int64_t numKSlices, @@ -570,9 +770,50 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, return failure(); } + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + const int64_t reductionSize = aType.getDimSize(1); + if (!isCompileTimeComputable(b)) { - gemmOp.emitOpError("requires Gemm input B to be statically computed from constants"); - return failure(); + bool hasC = hasGemmBias(c); + float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); + float beta = gemmOpAdaptor.getBeta().convertToFloat(); + RankedTensorType biasType; + if (hasC) { + auto cType = dyn_cast(c.getType()); + if (!cType || !cType.hasStaticShape()) { + pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias"); + return failure(); + } + auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType); + if (failed(verifiedBiasType)) { + gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape"); + return failure(); + } + biasType = *verifiedBiasType; + } + + const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize; + const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols; + if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows + || bType.getDimSize(1) != expectedBCols) { + gemmOp.emitOpError("has inconsistent A, B, and output shapes"); + return failure(); + } + + const int64_t laneCount64 = numOutRows * numOutCols; + if (laneCount64 > std::numeric_limits::max()) { + gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32"); + return failure(); + } + + auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType()); + auto batchOp = createVvdmulBatch( + a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc); + auto outputCompute = createDynamicGemmOutputCompute( + batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc); + rewriter.replaceOp(gemmOp, outputCompute.getResults()); + return success(); } auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc); @@ -590,9 +831,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, bType = cast(b.getType()); } - const int64_t numOutRows = outType.getDimSize(0); - const int64_t numOutCols = outType.getDimSize(1); - const int64_t reductionSize = aType.getDimSize(1); if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) { gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling"); return failure(); @@ -615,7 +853,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, aType = paddedAType; Value bias; - bool hasC = !isa(c.getDefiningOp()); + bool hasC = hasGemmBias(c); auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType()); if (hasC) { auto cType = dyn_cast(c.getType()); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 4d35f5e..20fc586 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -21,6 +21,12 @@ def spatToPimVMM : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; +def spatToPimVVDMul : Pat< + (SpatVVDMulOp:$srcOpRes $a, $b), + (PimVVDMulOp $a, $b, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + def spatToPimVVAdd : Pat< (SpatVAddOp:$srcOpRes $a, $b), (PimVVAddOp $a, $b, diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 981f473..ff2183f 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -524,6 +524,39 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto vvdmulOp = cast(op); + + auto lhsOpt = getBufferOrValue(rewriter, vvdmulOp.getLhs(), options, state); + if (failed(lhsOpt)) + return failure(); + + auto rhsOpt = getBufferOrValue(rewriter, vvdmulOp.getRhs(), options, state); + if (failed(rhsOpt)) + return failure(); + + auto outputBufferOpt = getBufferOrValue(rewriter, vvdmulOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) + return failure(); + + Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter); + Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter); + Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter); + + replaceOpWithNewBufferizedOp( + rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput); + return success(); + } +}; + template struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { @@ -576,6 +609,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimVVSubOp::attachInterface>(*ctx); PimVVMulOp::attachInterface>(*ctx); PimVVMaxOp::attachInterface>(*ctx); + PimVVDMulOp::attachInterface(*ctx); PimVAvgOp::attachInterface>(*ctx); PimVReluOp::attachInterface>(*ctx); diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index e5381ee..9669447 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -6,6 +6,7 @@ add_pim_library(SpatialOps SpatialOpsAsm.cpp SpatialOpsVerify.cpp SpatialOpsCanonicalization.cpp + ${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index bd4024c..cdd3464 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> { }]; } +def SpatVVDMulOp : SpatOp<"vvdmul", []> { + let summary = "Dot product between two runtime vectors"; + + let arguments = (ins + SpatTensor:$lhs, + SpatTensor:$rhs + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) + }]; +} + def SpatVAddOp : SpatOp<"vadd", []> { let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1"; diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index df421fc..115771f 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() { return success(); } +LogicalResult SpatVVDMulOp::verify() { + auto lhsType = dyn_cast(getLhs().getType()); + auto rhsType = dyn_cast(getRhs().getType()); + auto outputType = dyn_cast(getOutput().getType()); + if (!lhsType || !rhsType || !outputType) + return emitError("lhs, rhs, and output must be shaped values"); + if (!lhsType.hasRank() || !rhsType.hasRank() || !outputType.hasRank()) + return emitError("lhs, rhs, and output must have ranked types"); + + ArrayRef lhsShape = lhsType.getShape(); + ArrayRef rhsShape = rhsType.getShape(); + ArrayRef outputShape = outputType.getShape(); + if (lhsShape.size() != 2 || rhsShape.size() != 2 || outputShape.size() != 2) + return emitError("lhs, rhs, and output must have rank 2"); + if (lhsType.getElementType() != rhsType.getElementType() || lhsType.getElementType() != outputType.getElementType()) + return emitError("lhs, rhs, and output must have the same element type"); + if (lhsShape != rhsShape) + return emitError("lhs and rhs vector shapes must match"); + if (lhsShape[0] != 1 || lhsShape[1] <= 0) + return emitError("lhs and rhs vector shape must be (1, N) with N > 0"); + if (outputShape[0] != 1 || outputShape[1] != 1) + return emitError("output shape must be (1, 1)"); + + return success(); +} + LogicalResult SpatVAddOp::verify() { if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 32e5ffe..ecda21c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1019,27 +1019,6 @@ std::optional getIndexedIndexPattern(ArrayRef valu return std::nullopt; } -Value createAffineApplyOrConstant( - MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) { - SmallVector operandConstants; - operandConstants.reserve(operands.size()); - for (Value operand : operands) { - auto constantValue = getConstantIntValue(operand); - if (!constantValue) - return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult(); - operandConstants.push_back(state.rewriter.getIndexAttr(*constantValue)); - } - - SmallVector foldedResults; - if (succeeded(map.constantFold(operandConstants, foldedResults))) { - auto constantResult = dyn_cast(foldedResults.front()); - if (constantResult) - return createIndexConstant(state, anchor, constantResult.getInt()); - } - - return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult(); -} - Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { MLIRContext* context = state.func.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); @@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern } AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func); + return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func); } Value createIndexedIndexValue( @@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe int64_t laneCount = static_cast(targetClass.cpus.size()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); - return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func); + return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); } Value createBatchClassRunSourceLane(MaterializerState& state, diff --git a/validation/operations/README.md b/validation/operations/README.md index 2cce699..237d1c8 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -25,10 +25,11 @@ python3 validation/operations/gen_tests.py | Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input | | Grouped two groups | `conv/grouped_two_groups` | [1,4,4,4] | [1,4,4,4] | 1x1 | 1 | none | yes | group=2 channel partitioning | | Depthwise grouped | `conv/depthwise_grouped` | [1,3,4,4] | [1,3,2,2] | 3x3 | 1 | none | no | group=3, one input channel per group | +| Dynamic | `conv/dynamic` | [1,1,4,4] | [1,1,2,2] | 3x3 | 1 | none | no | Runtime input and weight | ## Gemm -| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes | +| Test | Directory | A (input) | B/W tensor | Output | transB | alpha | beta | Bias | Notes | |---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------| | Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights | | Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N | @@ -38,12 +39,20 @@ python3 validation/operations/gen_tests.py | Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices | | Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices | | transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined | +| Dynamic | `gemm/dynamic` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Runtime matrix operands | +| Dynamic transB | `gemm/dynamic_transB` | [2,8] | [4,8] | [2,4] | yes | 1 | 1 | no | Runtime transpose handling | +| Dynamic bias | `gemm/dynamic_bias` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | [4] | Runtime bias broadcast | +| Dynamic alpha | `gemm/dynamic_alpha` | [2,8] | [8,4] | [2,4] | no | 0.5 | 1 | no | Runtime alpha scaling | +| Dynamic beta | `gemm/dynamic_beta` | [2,8] | [8,4] | [2,4] | no | 1 | 2 | [4] | Runtime beta scaling | +| Dynamic bias + scale | `gemm/dynamic_bias_alpha_beta` | [2,8] | [8,4] | [2,4] | no | 0.5 | 2 | [4] | Runtime operands and bias | ## MatMul -| Test | Directory | A input | B weight | Output | Notes | +| Test | Directory | A input | B tensor | Output | Notes | |------------|---------------------|---------|----------|---------|------------------------------------| | Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path | +| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path | +| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands | | Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path | ## Gemv diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 6ae2af4..0ba74b0 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -181,6 +181,18 @@ def conv_depthwise_grouped(): save_model(model, "conv/depthwise_grouped", "conv_depthwise_grouped.onnx") +def conv_dynamic(): + """Conv with input and weight both provided at runtime.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4]) + W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 3, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2]) + node = helper.make_node("Conv", ["X", "W"], ["Y"], + kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + graph = helper.make_graph([node], "conv_dynamic", [X, W], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "conv/dynamic", "conv_dynamic.onnx") + + # --------------------------------------------------------------------------- # GEMM tests # --------------------------------------------------------------------------- @@ -291,6 +303,75 @@ def gemm_transB_with_bias(): save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx") +def gemm_dynamic(): + """GEMM with both matrix operands provided at runtime.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "gemm_dynamic", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic", "gemm_dynamic.onnx") + + +def gemm_dynamic_transB(): + """GEMM with runtime matrix operands and transposed runtime B.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B"], ["Y"], transB=1) + graph = helper.make_graph([node], "gemm_dynamic_transB", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic_transB", "gemm_dynamic_transB.onnx") + + +def gemm_dynamic_bias(): + """GEMM with runtime matrix operands and runtime bias.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4]) + C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"]) + graph = helper.make_graph([node], "gemm_dynamic_bias", [A, B, C], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic_bias", "gemm_dynamic_bias.onnx") + + +def gemm_dynamic_alpha(): + """GEMM with runtime matrix operands and runtime alpha scaling.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B"], ["Y"], alpha=0.5) + graph = helper.make_graph([node], "gemm_dynamic_alpha", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic_alpha", "gemm_dynamic_alpha.onnx") + + +def gemm_dynamic_beta(): + """GEMM with runtime matrix operands, runtime bias, and runtime beta scaling.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4]) + C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], beta=2.0) + graph = helper.make_graph([node], "gemm_dynamic_beta", [A, B, C], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic_beta", "gemm_dynamic_beta.onnx") + + +def gemm_dynamic_bias_alpha_beta(): + """GEMM with runtime matrix operands, runtime bias, alpha, and beta.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4]) + C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], alpha=0.5, beta=2.0) + graph = helper.make_graph([node], "gemm_dynamic_bias_alpha_beta", [A, B, C], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "gemm/dynamic_bias_alpha_beta", "gemm_dynamic_bias_alpha_beta.onnx") + + # --------------------------------------------------------------------------- # MatMul tests # --------------------------------------------------------------------------- @@ -306,6 +387,28 @@ def matmul_basic(): save_model(model, "matmul/basic", "matmul_basic.onnx") +def matmul_left_constant(): + """Direct 2D MatMul with constant LHS.""" + A = numpy_helper.from_array(np.random.default_rng(69).uniform(-1, 1, (2, 3)).astype(np.float32), name="A") + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_left_constant", [B], [Y], initializer=[A]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/left_constant", "matmul_left_constant.onnx") + + +def matmul_dynamic(): + """Direct 2D MatMul with both operands provided at runtime.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_dynamic", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/dynamic", "matmul_dynamic.onnx") + + def matmul_batched_3d(): """Batched 3D MatMul with matching batch dimensions.""" rng = np.random.default_rng(50) @@ -843,6 +946,12 @@ if __name__ == "__main__": gemm_small() gemm_large() gemm_transB_with_bias() + gemm_dynamic() + gemm_dynamic_transB() + gemm_dynamic_bias() + gemm_dynamic_alpha() + gemm_dynamic_beta() + gemm_dynamic_bias_alpha_beta() print("\nGenerating Conv tests:") conv_3x3_kernel() @@ -856,9 +965,12 @@ if __name__ == "__main__": conv_large_spatial() conv_grouped_two_groups() conv_depthwise_grouped() + conv_dynamic() print("\nGenerating MatMul tests:") matmul_basic() + matmul_left_constant() + matmul_dynamic() matmul_batched_3d() print("\nGenerating Pooling tests:")