Dynamic gemm/conv
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
- `cmake --build ./build_release`
|
- `cmake --build ./build_release`
|
||||||
- `cmake --build ./build_debug`
|
- `cmake --build ./build_debug`
|
||||||
- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
|
- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
|
||||||
|
- Always tries the release version build first and ask before building with the debug version
|
||||||
|
|
||||||
# Code changes
|
# Code changes
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use crate::utility::AddressArg;
|
||||||
use std::{collections::HashMap, fmt::Debug};
|
use std::{collections::HashMap, fmt::Debug};
|
||||||
use anyhow::{Context, Result, ensure};
|
use anyhow::{Context, Result, ensure};
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ use crate::{
|
|||||||
|
|
||||||
pub mod crossbar;
|
pub mod crossbar;
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct CPU<'a> {
|
pub struct CPU<'a> {
|
||||||
cores: Box<[Core<'a>]>,
|
cores: Box<[Core<'a>]>,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use anyhow::{Result,Context};
|
||||||
use std::{fmt::Debug, mem::transmute};
|
use std::{fmt::Debug, mem::transmute};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "ConstantUtils.hpp"
|
#include "ConstantUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
@@ -79,4 +81,24 @@ Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFo
|
|||||||
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value createAffineApplyOrFoldedConstant(
|
||||||
|
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) {
|
||||||
|
SmallVector<Attribute> operandConstants;
|
||||||
|
operandConstants.reserve(operands.size());
|
||||||
|
for (Value operand : operands) {
|
||||||
|
APInt constantValue;
|
||||||
|
if (!matchPattern(operand, m_ConstantInt(&constantValue)))
|
||||||
|
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||||
|
operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Attribute> foldedResults;
|
||||||
|
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
||||||
|
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
||||||
|
return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
@@ -29,4 +30,10 @@ mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value,
|
|||||||
|
|
||||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||||
|
|
||||||
|
mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter,
|
||||||
|
mlir::Location loc,
|
||||||
|
mlir::AffineMap map,
|
||||||
|
mlir::ValueRange operands,
|
||||||
|
mlir::Operation* anchorOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
#include "ShapeTilingUtils.hpp"
|
#include "ShapeTilingUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
@@ -12,6 +16,72 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
if (auto attr = dyn_cast<Attribute>(result))
|
||||||
|
return arith::ConstantIndexOp::create(rewriter, loc, cast<IntegerAttr>(attr).getInt()).getResult();
|
||||||
|
return cast<Value>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
APInt lhsConst;
|
||||||
|
if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero())
|
||||||
|
return rhs;
|
||||||
|
|
||||||
|
APInt rhsConst;
|
||||||
|
if (matchPattern(rhs, m_ConstantInt(&rhsConst)) && rhsConst.isZero())
|
||||||
|
return lhs;
|
||||||
|
|
||||||
|
return arith::AddIOp::create(rewriter, loc, lhs, rhs).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
APInt factorConst;
|
||||||
|
if (auto attr = dyn_cast<Attribute>(factor))
|
||||||
|
factorConst = cast<IntegerAttr>(attr).getValue();
|
||||||
|
else if (!matchPattern(cast<Value>(factor), m_ConstantInt(&factorConst)))
|
||||||
|
return arith::MulIOp::create(rewriter, loc, value, cast<Value>(factor)).getResult();
|
||||||
|
|
||||||
|
if (factorConst.isZero())
|
||||||
|
return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
|
if (factorConst.isOne())
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult();
|
||||||
|
return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef<OpFoldResult> strides) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
for (OpFoldResult stride : strides) {
|
||||||
|
APInt strideValue;
|
||||||
|
if (auto attr = dyn_cast<Attribute>(stride)) {
|
||||||
|
if (cast<IntegerAttr>(attr).getInt() != 1)
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!matchPattern(cast<Value>(stride), m_ConstantInt(&strideValue)) || !strideValue.isOne())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto sizesAndShape = llvm::zip_equal(llvm::make_range(resultType.getShape().rbegin(), resultType.getShape().rend()),
|
||||||
|
llvm::make_range(sourceType.getShape().rbegin(), sourceType.getShape().rend()));
|
||||||
|
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||||
|
auto [size, dimension] = sizeAndShape;
|
||||||
|
return size != dimension;
|
||||||
|
});
|
||||||
|
if (firstDifferentSize == sizesAndShape.end())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
++firstDifferentSize;
|
||||||
|
return std::all_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) {
|
||||||
|
auto [size, _dimension] = sizeAndShape;
|
||||||
|
return size == 1;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
SmallVector<Value> sliceTensor(
|
SmallVector<Value> sliceTensor(
|
||||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||||
@@ -123,4 +193,87 @@ Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatte
|
|||||||
return broadcastCompute.getResult(0);
|
return broadcastCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Value materializeContiguousTensorSlice(Value source,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<OpFoldResult> offsets,
|
||||||
|
ArrayRef<OpFoldResult> strides,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
assert(resultType.hasStaticShape() && "expected static result type");
|
||||||
|
size_t rank = static_cast<size_t>(resultType.getRank());
|
||||||
|
assert(offsets.size() == rank && "expected rank-matching offsets");
|
||||||
|
assert(strides.size() == rank && "expected rank-matching strides");
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(resultType.getRank());
|
||||||
|
for (int64_t size : resultType.getShape())
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(size));
|
||||||
|
|
||||||
|
if (isContiguousTensorSlice(source, resultType, strides))
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||||
|
|
||||||
|
if (resultType.getRank() == 0)
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult();
|
||||||
|
|
||||||
|
Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
|
||||||
|
SmallVector<Value> zeroIndices(resultType.getRank());
|
||||||
|
for (Value& zeroIndex : zeroIndices)
|
||||||
|
zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
|
|
||||||
|
SmallVector<Value> resultIndices;
|
||||||
|
resultIndices.reserve(resultType.getRank());
|
||||||
|
|
||||||
|
auto buildLoopNest = [&](auto&& self, unsigned dim, Value accumulator) -> Value {
|
||||||
|
if (dim == resultType.getRank()) {
|
||||||
|
SmallVector<Value> sourceIndices;
|
||||||
|
sourceIndices.reserve(resultType.getRank());
|
||||||
|
for (unsigned idx = 0; idx < resultType.getRank(); ++idx) {
|
||||||
|
Value offsetValue = getIndexValue(offsets[idx], rewriter, loc);
|
||||||
|
Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc);
|
||||||
|
sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> sourceOffsets;
|
||||||
|
SmallVector<OpFoldResult> destinationOffsets;
|
||||||
|
SmallVector<OpFoldResult> unitSizes;
|
||||||
|
SmallVector<OpFoldResult> unitStrides;
|
||||||
|
sourceOffsets.reserve(resultType.getRank());
|
||||||
|
destinationOffsets.reserve(resultType.getRank());
|
||||||
|
unitSizes.reserve(resultType.getRank());
|
||||||
|
unitStrides.reserve(resultType.getRank());
|
||||||
|
for (Value index : sourceIndices)
|
||||||
|
sourceOffsets.push_back(index);
|
||||||
|
for (Value index : resultIndices)
|
||||||
|
destinationOffsets.push_back(index);
|
||||||
|
for (int64_t idx = 0; idx < resultType.getRank(); ++idx) {
|
||||||
|
unitSizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
unitStrides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto elementTensorType =
|
||||||
|
RankedTensorType::get(SmallVector<int64_t>(resultType.getRank(), 1), resultType.getElementType());
|
||||||
|
Value elementSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, elementTensorType, source, sourceOffsets, unitSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
return tensor::InsertSliceOp::create(
|
||||||
|
rewriter, loc, elementSlice, accumulator, destinationOffsets, unitSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value lower = zeroIndices[dim];
|
||||||
|
Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult();
|
||||||
|
Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult();
|
||||||
|
auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
resultIndices.push_back(loop.getInductionVar());
|
||||||
|
Value updated = self(self, dim + 1, loop.getRegionIterArgs().front());
|
||||||
|
resultIndices.pop_back();
|
||||||
|
scf::YieldOp::create(rewriter, loc, updated);
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
return loop.getResult(0);
|
||||||
|
};
|
||||||
|
|
||||||
|
return buildLoopNest(buildLoopNest, 0, init);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -141,4 +141,11 @@ mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
|||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|
||||||
|
mlir::Value materializeContiguousTensorSlice(mlir::Value source,
|
||||||
|
mlir::RankedTensorType resultType,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> offsets,
|
||||||
|
llvm::ArrayRef<mlir::OpFoldResult> strides,
|
||||||
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
|
mlir::Location loc);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -111,6 +111,32 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
|||||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createConvWeightMatrix(Value w,
|
||||||
|
RankedTensorType wFlatType,
|
||||||
|
RankedTensorType wTransType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto buildWeightMatrix = [&](Value weight) -> Value {
|
||||||
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
wFlatType,
|
||||||
|
weight,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isCompileTimeComputable(w))
|
||||||
|
return buildWeightMatrix(w);
|
||||||
|
|
||||||
|
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
|
||||||
|
});
|
||||||
|
return computeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static Value buildPackedBias(bool hasBias,
|
static Value buildPackedBias(bool hasBias,
|
||||||
Value gemmBias,
|
Value gemmBias,
|
||||||
Value biasMatrix,
|
Value biasMatrix,
|
||||||
@@ -395,15 +421,7 @@ static Value lowerSingleConvGroup(Value x,
|
|||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
// Prepare weight matrix W for crossbar storage:
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
|
||||||
loc,
|
|
||||||
wFlatType,
|
|
||||||
w,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2, 3}
|
|
||||||
});
|
|
||||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
|
|
||||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
|||||||
@@ -73,38 +73,11 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va
|
|||||||
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
return getOrCreateHostIndexConstant(anchorOp, value, rewriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::optional<int64_t> getConstantIndexValue(Value value) {
|
|
||||||
if (auto constantIndex = value.getDefiningOp<arith::ConstantIndexOp>())
|
|
||||||
return constantIndex.value();
|
|
||||||
|
|
||||||
APInt constantValue;
|
|
||||||
if (matchPattern(value, m_ConstantInt(&constantValue)))
|
|
||||||
return constantValue.getSExtValue();
|
|
||||||
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) {
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr);
|
||||||
|
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||||
SmallVector<Attribute> operandConstants;
|
return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp);
|
||||||
operandConstants.reserve(operands.size());
|
|
||||||
for (Value operand : operands) {
|
|
||||||
std::optional<int64_t> constantValue = getConstantIndexValue(operand);
|
|
||||||
if (!constantValue)
|
|
||||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
|
||||||
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Attribute> foldedResults;
|
|
||||||
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
|
||||||
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
|
|
||||||
if (constantResult)
|
|
||||||
return createIndexConstant(rewriter, constantResult.getInt());
|
|
||||||
}
|
|
||||||
|
|
||||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value
|
static Value
|
||||||
@@ -379,6 +352,233 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
|
|||||||
return batchOp;
|
return batchOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value createDynamicGemmBatchRow(
|
||||||
|
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
if (numOutCols == 1)
|
||||||
|
return lane;
|
||||||
|
|
||||||
|
MLIRContext* context = rewriter.getContext();
|
||||||
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
|
return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createDynamicGemmBatchColumn(
|
||||||
|
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
return modIndexByConstant(lane, numOutCols, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value
|
||||||
|
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
|
||||||
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||||
|
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
|
||||||
|
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}};
|
||||||
|
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||||
|
Value collapsed =
|
||||||
|
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
|
||||||
|
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}};
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractTransposedBRow(
|
||||||
|
Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value extractDynamicGemmRowVector(
|
||||||
|
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
|
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||||
|
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, offsets, sizes, strides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<RankedTensorType> verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) {
|
||||||
|
if (!cType.hasStaticShape() || cType.getRank() > 2)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (cType.getRank() == 0)
|
||||||
|
return cType;
|
||||||
|
|
||||||
|
int64_t numOutRows = outType.getDimSize(0);
|
||||||
|
int64_t numOutCols = outType.getDimSize(1);
|
||||||
|
if (cType.getRank() == 1) {
|
||||||
|
int64_t cols = cType.getDimSize(0);
|
||||||
|
if (cols == 1 || cols == numOutCols)
|
||||||
|
return cType;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t rows = cType.getDimSize(0);
|
||||||
|
int64_t cols = cType.getDimSize(1);
|
||||||
|
if ((rows == 1 || rows == numOutRows) && (cols == 1 || cols == numOutCols))
|
||||||
|
return cType;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hasGemmBias(Value c) {
|
||||||
|
Operation* definingOp = c.getDefiningOp();
|
||||||
|
return !definingOp || !isa<ONNXNoneOp>(definingOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createScalarTensorConstant(RankedTensorType scalarType,
|
||||||
|
float value,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto elementType = scalarType.getElementType();
|
||||||
|
auto scalarAttr = rewriter.getFloatAttr(elementType, value);
|
||||||
|
auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createBroadcastedBiasScalar(Value bias,
|
||||||
|
RankedTensorType biasType,
|
||||||
|
Value row,
|
||||||
|
Value column,
|
||||||
|
RankedTensorType scalarType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
|
||||||
|
if (biasType.getRank() == 1) {
|
||||||
|
SmallVector<OpFoldResult> offsets {
|
||||||
|
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
|
||||||
|
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
|
||||||
|
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}};
|
||||||
|
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (biasType.getRank() == 2) {
|
||||||
|
SmallVector<OpFoldResult> offsets {
|
||||||
|
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(row),
|
||||||
|
biasType.getDimSize(1) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
|
||||||
|
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, unitStrides).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value scalar = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult();
|
||||||
|
return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatComputeBatch createVvdmulBatch(Value a,
|
||||||
|
Value b,
|
||||||
|
RankedTensorType aType,
|
||||||
|
RankedTensorType bType,
|
||||||
|
RankedTensorType scalarPiecesType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
bool bAlreadyTransposed,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t numOutRows = outType.getDimSize(0);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
|
const int64_t laneCount = numOutRows * numOutCols;
|
||||||
|
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||||
|
loc,
|
||||||
|
TypeRange {scalarPiecesType},
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)),
|
||||||
|
ValueRange {},
|
||||||
|
ValueRange {a, b});
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType};
|
||||||
|
SmallVector<Location> blockArgLocs(blockArgTypes.size(), loc);
|
||||||
|
Block* body =
|
||||||
|
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(body);
|
||||||
|
|
||||||
|
auto lane = batchOp.getLaneArgument();
|
||||||
|
auto inputA = batchOp.getInputArgument(0);
|
||||||
|
auto inputB = batchOp.getInputArgument(1);
|
||||||
|
auto output = batchOp.getOutputArgument(0);
|
||||||
|
assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body");
|
||||||
|
|
||||||
|
Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc);
|
||||||
|
Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc);
|
||||||
|
|
||||||
|
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||||
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
|
Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc);
|
||||||
|
Value bVector = bAlreadyTransposed
|
||||||
|
? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc)
|
||||||
|
: extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc);
|
||||||
|
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||||
|
|
||||||
|
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||||
|
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(batchOp);
|
||||||
|
return batchOp;
|
||||||
|
}
|
||||||
|
|
||||||
|
static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces,
|
||||||
|
Value bias,
|
||||||
|
RankedTensorType scalarPiecesType,
|
||||||
|
RankedTensorType biasType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
float alpha,
|
||||||
|
float beta,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
const int64_t laneCount = scalarPiecesType.getDimSize(0);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
|
SmallVector<Value> inputs {scalarPieces};
|
||||||
|
if (bias)
|
||||||
|
inputs.push_back(bias);
|
||||||
|
|
||||||
|
return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) {
|
||||||
|
Value pieces = blockArgs[0];
|
||||||
|
Value biasArg = bias ? blockArgs[1] : Value();
|
||||||
|
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||||
|
Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||||
|
Value c0 = createIndexConstant(rewriter, 0);
|
||||||
|
Value c1 = createIndexConstant(rewriter, 1);
|
||||||
|
Value cLaneCount = createIndexConstant(rewriter, laneCount);
|
||||||
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
Value lane = loop.getInductionVar();
|
||||||
|
Value outputAcc = loop.getRegionIterArgs().front();
|
||||||
|
Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc);
|
||||||
|
Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc);
|
||||||
|
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
Value scalar =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
if (alpha != 1.0f) {
|
||||||
|
Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, loc);
|
||||||
|
scalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, scalar, alphaTensor).getResult();
|
||||||
|
}
|
||||||
|
if (biasArg) {
|
||||||
|
Value biasScalar = createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, loc);
|
||||||
|
if (beta != 1.0f) {
|
||||||
|
Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, loc);
|
||||||
|
biasScalar = spatial::SpatVMulOp::create(rewriter, loc, scalarType, biasScalar, betaTensor).getResult();
|
||||||
|
}
|
||||||
|
scalar = spatial::SpatVAddOp::create(rewriter, loc, scalarType, scalar, biasScalar).getResult();
|
||||||
|
}
|
||||||
|
SmallVector<OpFoldResult> outputOffsets {row, column};
|
||||||
|
Value outputNext =
|
||||||
|
tensor::InsertSliceOp::create(rewriter, loc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
|
||||||
|
.getResult();
|
||||||
|
scf::YieldOp::create(rewriter, loc, outputNext);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
static Value createPartialGroupOffset(Value hSlice,
|
static Value createPartialGroupOffset(Value hSlice,
|
||||||
int64_t kSlice,
|
int64_t kSlice,
|
||||||
int64_t numKSlices,
|
int64_t numKSlices,
|
||||||
@@ -570,10 +770,51 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t numOutRows = outType.getDimSize(0);
|
||||||
|
const int64_t numOutCols = outType.getDimSize(1);
|
||||||
|
const int64_t reductionSize = aType.getDimSize(1);
|
||||||
|
|
||||||
if (!isCompileTimeComputable(b)) {
|
if (!isCompileTimeComputable(b)) {
|
||||||
gemmOp.emitOpError("requires Gemm input B to be statically computed from constants");
|
bool hasC = hasGemmBias(c);
|
||||||
|
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
|
||||||
|
float beta = gemmOpAdaptor.getBeta().convertToFloat();
|
||||||
|
RankedTensorType biasType;
|
||||||
|
if (hasC) {
|
||||||
|
auto cType = dyn_cast<RankedTensorType>(c.getType());
|
||||||
|
if (!cType || !cType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType);
|
||||||
|
if (failed(verifiedBiasType)) {
|
||||||
|
gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
biasType = *verifiedBiasType;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
|
||||||
|
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
|
||||||
|
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|
||||||
|
|| bType.getDimSize(1) != expectedBCols) {
|
||||||
|
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t laneCount64 = numOutRows * numOutCols;
|
||||||
|
if (laneCount64 > std::numeric_limits<int32_t>::max()) {
|
||||||
|
gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
|
||||||
|
auto batchOp = createVvdmulBatch(
|
||||||
|
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
|
||||||
|
auto outputCompute = createDynamicGemmOutputCompute(
|
||||||
|
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
|
||||||
|
rewriter.replaceOp(gemmOp, outputCompute.getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
|
||||||
if (failed(scaledB)) {
|
if (failed(scaledB)) {
|
||||||
@@ -590,9 +831,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
bType = cast<RankedTensorType>(b.getType());
|
bType = cast<RankedTensorType>(b.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t numOutRows = outType.getDimSize(0);
|
|
||||||
const int64_t numOutCols = outType.getDimSize(1);
|
|
||||||
const int64_t reductionSize = aType.getDimSize(1);
|
|
||||||
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
|
||||||
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
|
||||||
return failure();
|
return failure();
|
||||||
@@ -615,7 +853,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
aType = paddedAType;
|
aType = paddedAType;
|
||||||
|
|
||||||
Value bias;
|
Value bias;
|
||||||
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
|
bool hasC = hasGemmBias(c);
|
||||||
auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
|
auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
auto cType = dyn_cast<RankedTensorType>(c.getType());
|
auto cType = dyn_cast<RankedTensorType>(c.getType());
|
||||||
|
|||||||
@@ -21,6 +21,12 @@ def spatToPimVMM : Pat<
|
|||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
def spatToPimVVDMul : Pat<
|
||||||
|
(SpatVVDMulOp:$srcOpRes $a, $b),
|
||||||
|
(PimVVDMulOp $a, $b,
|
||||||
|
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||||
|
>;
|
||||||
|
|
||||||
def spatToPimVVAdd : Pat<
|
def spatToPimVVAdd : Pat<
|
||||||
(SpatVAddOp:$srcOpRes $a, $b),
|
(SpatVAddOp:$srcOpRes $a, $b),
|
||||||
(PimVVAddOp $a, $b,
|
(PimVVAddOp $a, $b,
|
||||||
|
|||||||
@@ -524,6 +524,39 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInterface, PimVVDMulOp> {
|
||||||
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
|
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult bufferize(Operation* op,
|
||||||
|
RewriterBase& rewriter,
|
||||||
|
const BufferizationOptions& options,
|
||||||
|
BufferizationState& state) const {
|
||||||
|
auto vvdmulOp = cast<PimVVDMulOp>(op);
|
||||||
|
|
||||||
|
auto lhsOpt = getBufferOrValue(rewriter, vvdmulOp.getLhs(), options, state);
|
||||||
|
if (failed(lhsOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto rhsOpt = getBufferOrValue(rewriter, vvdmulOp.getRhs(), options, state);
|
||||||
|
if (failed(rhsOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto outputBufferOpt = getBufferOrValue(rewriter, vvdmulOp.getOutputBuffer(), options, state);
|
||||||
|
if (failed(outputBufferOpt))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
Value contiguousLhs = materializeContiguousMemRef(*lhsOpt, op->getLoc(), rewriter);
|
||||||
|
Value contiguousRhs = materializeContiguousMemRef(*rhsOpt, op->getLoc(), rewriter);
|
||||||
|
Value contiguousOutput = allocateContiguousMemRefLike(*outputBufferOpt, op->getLoc(), rewriter);
|
||||||
|
|
||||||
|
replaceOpWithNewBufferizedOp<PimVVDMulOp>(
|
||||||
|
rewriter, op, contiguousOutput.getType(), contiguousLhs, contiguousRhs, contiguousOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
|
||||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||||
@@ -576,6 +609,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
|||||||
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
|
||||||
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
|
||||||
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
|
||||||
|
PimVVDMulOp::attachInterface<VVDMulOpInterface>(*ctx);
|
||||||
|
|
||||||
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
|
||||||
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsAsm.cpp
|
SpatialOpsAsm.cpp
|
||||||
SpatialOpsVerify.cpp
|
SpatialOpsVerify.cpp
|
||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
|
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||||
|
|||||||
@@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
|||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
|
||||||
|
let summary = "Dot product between two runtime vectors";
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
SpatTensor:$lhs,
|
||||||
|
SpatTensor:$rhs
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
SpatTensor:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let hasVerifier = 1;
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def SpatVAddOp : SpatOp<"vadd", []> {
|
def SpatVAddOp : SpatOp<"vadd", []> {
|
||||||
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
|
||||||
|
|
||||||
|
|||||||
@@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult SpatVVDMulOp::verify() {
|
||||||
|
auto lhsType = dyn_cast<ShapedType>(getLhs().getType());
|
||||||
|
auto rhsType = dyn_cast<ShapedType>(getRhs().getType());
|
||||||
|
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||||
|
if (!lhsType || !rhsType || !outputType)
|
||||||
|
return emitError("lhs, rhs, and output must be shaped values");
|
||||||
|
if (!lhsType.hasRank() || !rhsType.hasRank() || !outputType.hasRank())
|
||||||
|
return emitError("lhs, rhs, and output must have ranked types");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> lhsShape = lhsType.getShape();
|
||||||
|
ArrayRef<int64_t> rhsShape = rhsType.getShape();
|
||||||
|
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||||
|
if (lhsShape.size() != 2 || rhsShape.size() != 2 || outputShape.size() != 2)
|
||||||
|
return emitError("lhs, rhs, and output must have rank 2");
|
||||||
|
if (lhsType.getElementType() != rhsType.getElementType() || lhsType.getElementType() != outputType.getElementType())
|
||||||
|
return emitError("lhs, rhs, and output must have the same element type");
|
||||||
|
if (lhsShape != rhsShape)
|
||||||
|
return emitError("lhs and rhs vector shapes must match");
|
||||||
|
if (lhsShape[0] != 1 || lhsShape[1] <= 0)
|
||||||
|
return emitError("lhs and rhs vector shape must be (1, N) with N > 0");
|
||||||
|
if (outputShape[0] != 1 || outputShape[1] != 1)
|
||||||
|
return emitError("output shape must be (1, 1)");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatVAddOp::verify() {
|
LogicalResult SpatVAddOp::verify() {
|
||||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -1019,27 +1019,6 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createAffineApplyOrConstant(
|
|
||||||
MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
|
|
||||||
SmallVector<Attribute> operandConstants;
|
|
||||||
operandConstants.reserve(operands.size());
|
|
||||||
for (Value operand : operands) {
|
|
||||||
auto constantValue = getConstantIntValue(operand);
|
|
||||||
if (!constantValue)
|
|
||||||
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
|
|
||||||
operandConstants.push_back(state.rewriter.getIndexAttr(*constantValue));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Attribute> foldedResults;
|
|
||||||
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
|
|
||||||
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
|
|
||||||
if (constantResult)
|
|
||||||
return createIndexConstant(state, anchor, constantResult.getInt());
|
|
||||||
}
|
|
||||||
|
|
||||||
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
|
|
||||||
}
|
|
||||||
|
|
||||||
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
|
||||||
MLIRContext* context = state.func.getContext();
|
MLIRContext* context = state.func.getContext();
|
||||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||||
@@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
|
|||||||
}
|
}
|
||||||
|
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
|
||||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
|
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createIndexedIndexValue(
|
Value createIndexedIndexValue(
|
||||||
@@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
|
|||||||
|
|
||||||
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
|
||||||
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
|
||||||
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value createBatchClassRunSourceLane(MaterializerState& state,
|
Value createBatchClassRunSourceLane(MaterializerState& state,
|
||||||
|
|||||||
@@ -25,10 +25,11 @@ python3 validation/operations/gen_tests.py
|
|||||||
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
| Large spatial | `conv/large_spatial` | [1,1,8,8] | [1,1,6,6] | 3x3 | 1 | none | no | Larger spatial input |
|
||||||
| Grouped two groups | `conv/grouped_two_groups` | [1,4,4,4] | [1,4,4,4] | 1x1 | 1 | none | yes | group=2 channel partitioning |
|
| Grouped two groups | `conv/grouped_two_groups` | [1,4,4,4] | [1,4,4,4] | 1x1 | 1 | none | yes | group=2 channel partitioning |
|
||||||
| Depthwise grouped | `conv/depthwise_grouped` | [1,3,4,4] | [1,3,2,2] | 3x3 | 1 | none | no | group=3, one input channel per group |
|
| Depthwise grouped | `conv/depthwise_grouped` | [1,3,4,4] | [1,3,2,2] | 3x3 | 1 | none | no | group=3, one input channel per group |
|
||||||
|
| Dynamic | `conv/dynamic` | [1,1,4,4] | [1,1,2,2] | 3x3 | 1 | none | no | Runtime input and weight |
|
||||||
|
|
||||||
## Gemm
|
## Gemm
|
||||||
|
|
||||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
| Test | Directory | A (input) | B/W tensor | Output | transB | alpha | beta | Bias | Notes |
|
||||||
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
||||||
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
|
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
|
||||||
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||||
@@ -38,12 +39,20 @@ python3 validation/operations/gen_tests.py
|
|||||||
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
|
| Small | `gemm/small` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Tiny matrices |
|
||||||
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
|
| Large | `gemm/large` | [8,256] | [256,128] | [8,128] | no | 1 | 1 | no | Larger matrices |
|
||||||
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
|
| transB + bias | `gemm/transB_with_bias` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | [64] | Combined |
|
||||||
|
| Dynamic | `gemm/dynamic` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | no | Runtime matrix operands |
|
||||||
|
| Dynamic transB | `gemm/dynamic_transB` | [2,8] | [4,8] | [2,4] | yes | 1 | 1 | no | Runtime transpose handling |
|
||||||
|
| Dynamic bias | `gemm/dynamic_bias` | [2,8] | [8,4] | [2,4] | no | 1 | 1 | [4] | Runtime bias broadcast |
|
||||||
|
| Dynamic alpha | `gemm/dynamic_alpha` | [2,8] | [8,4] | [2,4] | no | 0.5 | 1 | no | Runtime alpha scaling |
|
||||||
|
| Dynamic beta | `gemm/dynamic_beta` | [2,8] | [8,4] | [2,4] | no | 1 | 2 | [4] | Runtime beta scaling |
|
||||||
|
| Dynamic bias + scale | `gemm/dynamic_bias_alpha_beta` | [2,8] | [8,4] | [2,4] | no | 0.5 | 2 | [4] | Runtime operands and bias |
|
||||||
|
|
||||||
## MatMul
|
## MatMul
|
||||||
|
|
||||||
| Test | Directory | A input | B weight | Output | Notes |
|
| Test | Directory | A input | B tensor | Output | Notes |
|
||||||
|------------|---------------------|---------|----------|---------|------------------------------------|
|
|------------|---------------------|---------|----------|---------|------------------------------------|
|
||||||
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
|
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
|
||||||
|
| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path |
|
||||||
|
| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands |
|
||||||
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path |
|
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path |
|
||||||
|
|
||||||
## Gemv
|
## Gemv
|
||||||
|
|||||||
@@ -181,6 +181,18 @@ def conv_depthwise_grouped():
|
|||||||
save_model(model, "conv/depthwise_grouped", "conv_depthwise_grouped.onnx")
|
save_model(model, "conv/depthwise_grouped", "conv_depthwise_grouped.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def conv_dynamic():
|
||||||
|
"""Conv with input and weight both provided at runtime."""
|
||||||
|
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
||||||
|
W = helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 3, 3])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
|
||||||
|
node = helper.make_node("Conv", ["X", "W"], ["Y"],
|
||||||
|
kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||||
|
graph = helper.make_graph([node], "conv_dynamic", [X, W], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "conv/dynamic", "conv_dynamic.onnx")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# GEMM tests
|
# GEMM tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -291,6 +303,75 @@ def gemm_transB_with_bias():
|
|||||||
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
|
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic():
|
||||||
|
"""GEMM with both matrix operands provided at runtime."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic", [A, B], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic", "gemm_dynamic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic_transB():
|
||||||
|
"""GEMM with runtime matrix operands and transposed runtime B."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B"], ["Y"], transB=1)
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic_transB", [A, B], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic_transB", "gemm_dynamic_transB.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic_bias():
|
||||||
|
"""GEMM with runtime matrix operands and runtime bias."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4])
|
||||||
|
C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic_bias", [A, B, C], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic_bias", "gemm_dynamic_bias.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic_alpha():
|
||||||
|
"""GEMM with runtime matrix operands and runtime alpha scaling."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B"], ["Y"], alpha=0.5)
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic_alpha", [A, B], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic_alpha", "gemm_dynamic_alpha.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic_beta():
|
||||||
|
"""GEMM with runtime matrix operands, runtime bias, and runtime beta scaling."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4])
|
||||||
|
C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], beta=2.0)
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic_beta", [A, B, C], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic_beta", "gemm_dynamic_beta.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def gemm_dynamic_bias_alpha_beta():
|
||||||
|
"""GEMM with runtime matrix operands, runtime bias, alpha, and beta."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 8])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [8, 4])
|
||||||
|
C = helper.make_tensor_value_info("C", TensorProto.FLOAT, [4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], alpha=0.5, beta=2.0)
|
||||||
|
graph = helper.make_graph([node], "gemm_dynamic_bias_alpha_beta", [A, B, C], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "gemm/dynamic_bias_alpha_beta", "gemm_dynamic_bias_alpha_beta.onnx")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MatMul tests
|
# MatMul tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -306,6 +387,28 @@ def matmul_basic():
|
|||||||
save_model(model, "matmul/basic", "matmul_basic.onnx")
|
save_model(model, "matmul/basic", "matmul_basic.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_left_constant():
|
||||||
|
"""Direct 2D MatMul with constant LHS."""
|
||||||
|
A = numpy_helper.from_array(np.random.default_rng(69).uniform(-1, 1, (2, 3)).astype(np.float32), name="A")
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_left_constant", [B], [Y], initializer=[A])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/left_constant", "matmul_left_constant.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
def matmul_dynamic():
|
||||||
|
"""Direct 2D MatMul with both operands provided at runtime."""
|
||||||
|
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3])
|
||||||
|
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 4])
|
||||||
|
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
|
||||||
|
node = helper.make_node("MatMul", ["A", "B"], ["Y"])
|
||||||
|
graph = helper.make_graph([node], "matmul_dynamic", [A, B], [Y])
|
||||||
|
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||||
|
save_model(model, "matmul/dynamic", "matmul_dynamic.onnx")
|
||||||
|
|
||||||
|
|
||||||
def matmul_batched_3d():
|
def matmul_batched_3d():
|
||||||
"""Batched 3D MatMul with matching batch dimensions."""
|
"""Batched 3D MatMul with matching batch dimensions."""
|
||||||
rng = np.random.default_rng(50)
|
rng = np.random.default_rng(50)
|
||||||
@@ -843,6 +946,12 @@ if __name__ == "__main__":
|
|||||||
gemm_small()
|
gemm_small()
|
||||||
gemm_large()
|
gemm_large()
|
||||||
gemm_transB_with_bias()
|
gemm_transB_with_bias()
|
||||||
|
gemm_dynamic()
|
||||||
|
gemm_dynamic_transB()
|
||||||
|
gemm_dynamic_bias()
|
||||||
|
gemm_dynamic_alpha()
|
||||||
|
gemm_dynamic_beta()
|
||||||
|
gemm_dynamic_bias_alpha_beta()
|
||||||
|
|
||||||
print("\nGenerating Conv tests:")
|
print("\nGenerating Conv tests:")
|
||||||
conv_3x3_kernel()
|
conv_3x3_kernel()
|
||||||
@@ -856,9 +965,12 @@ if __name__ == "__main__":
|
|||||||
conv_large_spatial()
|
conv_large_spatial()
|
||||||
conv_grouped_two_groups()
|
conv_grouped_two_groups()
|
||||||
conv_depthwise_grouped()
|
conv_depthwise_grouped()
|
||||||
|
conv_dynamic()
|
||||||
|
|
||||||
print("\nGenerating MatMul tests:")
|
print("\nGenerating MatMul tests:")
|
||||||
matmul_basic()
|
matmul_basic()
|
||||||
|
matmul_left_constant()
|
||||||
|
matmul_dynamic()
|
||||||
matmul_batched_3d()
|
matmul_batched_3d()
|
||||||
|
|
||||||
print("\nGenerating Pooling tests:")
|
print("\nGenerating Pooling tests:")
|
||||||
|
|||||||
Reference in New Issue
Block a user