Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
+1
View File
@@ -6,6 +6,7 @@ add_pim_library(SpatialOps
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
+19
View File
@@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
}];
}
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
let summary = "Dot product between two runtime vectors";
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
let results = (outs
SpatTensor:$output
);
let hasVerifier = 1;
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
@@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() {
return success();
}
LogicalResult SpatVVDMulOp::verify() {
auto lhsType = dyn_cast<ShapedType>(getLhs().getType());
auto rhsType = dyn_cast<ShapedType>(getRhs().getType());
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!lhsType || !rhsType || !outputType)
return emitError("lhs, rhs, and output must be shaped values");
if (!lhsType.hasRank() || !rhsType.hasRank() || !outputType.hasRank())
return emitError("lhs, rhs, and output must have ranked types");
ArrayRef<int64_t> lhsShape = lhsType.getShape();
ArrayRef<int64_t> rhsShape = rhsType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (lhsShape.size() != 2 || rhsShape.size() != 2 || outputShape.size() != 2)
return emitError("lhs, rhs, and output must have rank 2");
if (lhsType.getElementType() != rhsType.getElementType() || lhsType.getElementType() != outputType.getElementType())
return emitError("lhs, rhs, and output must have the same element type");
if (lhsShape != rhsShape)
return emitError("lhs and rhs vector shapes must match");
if (lhsShape[0] != 1 || lhsShape[1] <= 0)
return emitError("lhs and rhs vector shape must be (1, N) with N > 0");
if (outputShape[0] != 1 || outputShape[1] != 1)
return emitError("output shape must be (1, 1)");
return success();
}
LogicalResult SpatVAddOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
@@ -1019,27 +1019,6 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
return std::nullopt;
}
Value createAffineApplyOrConstant(
MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
auto constantValue = getConstantIntValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
operandConstants.push_back(state.rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults))) {
auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front());
if (constantResult)
return createIndexConstant(state, anchor, constantResult.getInt());
}
return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
}
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
@@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
}
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func);
}
Value createIndexedIndexValue(
@@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
}
Value createBatchClassRunSourceLane(MaterializerState& state,