Dynamic gemm/conv
This commit is contained in:
@@ -111,6 +111,32 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||
}
|
||||
|
||||
static Value createConvWeightMatrix(Value w,
|
||||
RankedTensorType wFlatType,
|
||||
RankedTensorType wTransType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto buildWeightMatrix = [&](Value weight) -> Value {
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
weight,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(w))
|
||||
return buildWeightMatrix(w);
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildPackedBias(bool hasBias,
|
||||
Value gemmBias,
|
||||
Value biasMatrix,
|
||||
@@ -395,15 +421,7 @@ static Value lowerSingleConvGroup(Value x,
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
|
||||
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
Reference in New Issue
Block a user