Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
@@ -111,6 +111,32 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
}
static Value createConvWeightMatrix(Value w,
RankedTensorType wFlatType,
RankedTensorType wTransType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
weight,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
};
if (isCompileTimeComputable(w))
return buildWeightMatrix(w);
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
});
return computeOp.getResult(0);
}
static Value buildPackedBias(bool hasBias,
Value gemmBias,
Value biasMatrix,
@@ -395,15 +421,7 @@ static Value lowerSingleConvGroup(Value x,
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());