Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
@@ -524,6 +524,39 @@ struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<BinaryDstO
}
};
/// Bufferization external model attached to PimVVDMulOp.
///
/// The op is rebuilt on buffers: lhs and rhs are routed through
/// materializeContiguousMemRef, while the output comes from
/// allocateContiguousMemRefLike applied to the resolved output buffer.
/// NOTE(review): judging by the helper's name the output is a new allocation
/// rather than the original init buffer — confirm against the helper.
struct VVDMulOpInterface : DstBufferizableOpInterfaceExternalModel<VVDMulOpInterface, PimVVDMulOp> {
  /// Reports a memory read for every operand that is not a DPS init of `op`.
  bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
    auto dpsOp = cast<DestinationStyleOpInterface>(op);
    const bool isInitOperand = dpsOp.isDpsInit(&opOperand);
    return !isInitOperand;
  }

  /// Replaces `op` with a PimVVDMulOp that works on buffers.
  ///
  /// Returns failure() as soon as one of lhs, rhs or the output buffer cannot
  /// be resolved through getBufferOrValue; otherwise returns success().
  LogicalResult bufferize(Operation* op,
                          RewriterBase& rewriter,
                          const BufferizationOptions& options,
                          BufferizationState& state) const {
    auto dotOp = cast<PimVVDMulOp>(op);
    auto loc = op->getLoc();

    // Resolve the operands one by one (lhs, rhs, output) and bail out on the
    // first failure so that no later operand is touched.
    auto lhsBuffer = getBufferOrValue(rewriter, dotOp.getLhs(), options, state);
    if (failed(lhsBuffer))
      return failure();

    auto rhsBuffer = getBufferOrValue(rewriter, dotOp.getRhs(), options, state);
    if (failed(rhsBuffer))
      return failure();

    auto initBuffer = getBufferOrValue(rewriter, dotOp.getOutputBuffer(), options, state);
    if (failed(initBuffer))
      return failure();

    // Helper call order is kept as lhs, rhs, output: the helpers may emit IR
    // through the rewriter, so the order can be observable.
    Value lhsMemRef = materializeContiguousMemRef(*lhsBuffer, loc, rewriter);
    Value rhsMemRef = materializeContiguousMemRef(*rhsBuffer, loc, rewriter);
    Value resultMemRef = allocateContiguousMemRefLike(*initBuffer, loc, rewriter);

    replaceOpWithNewBufferizedOp<PimVVDMulOp>(
        rewriter, op, resultMemRef.getType(), lhsMemRef, rhsMemRef, resultMemRef);
    return success();
  }
};
template <typename OpTy>
struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel<UnaryDstOpInterface<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
@@ -576,6 +609,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimVVSubOp::attachInterface<BinaryDstOpInterface<PimVVSubOp>>(*ctx);
PimVVMulOp::attachInterface<BinaryDstOpInterface<PimVVMulOp>>(*ctx);
PimVVMaxOp::attachInterface<BinaryDstOpInterface<PimVVMaxOp>>(*ctx);
PimVVDMulOp::attachInterface<VVDMulOpInterface>(*ctx);
PimVAvgOp::attachInterface<UnaryDstOpInterface<PimVAvgOp>>(*ctx);
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
+1
View File
@@ -6,6 +6,7 @@ add_pim_library(SpatialOps
SpatialOpsAsm.cpp
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
${PIM_SRC_ROOT}/Conversion/ONNXToSpatial/CompileTime.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
+19
View File
@@ -219,6 +219,25 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
}];
}
// spat.vvdmul: dot product of two vector operands producing a single tensor
// result. The op is declared with an empty trait list.
def SpatVVDMulOp : SpatOp<"vvdmul", []> {
let summary = "Dot product between two runtime vectors";
// Both operands are SSA tensor values (SpatTensor); neither is an attribute.
let arguments = (ins
SpatTensor:$lhs,
SpatTensor:$rhs
);
// Single tensor result holding the product.
let results = (outs
SpatTensor:$output
);
// Shape and element-type constraints are enforced by the hand-written C++
// verifier SpatVVDMulOp::verify(): lhs and rhs must be rank-2 tensors of
// identical shape (1, N) with N > 0, the output must be (1, 1), and all three
// must share one element type.
let hasVerifier = 1;
// Textual form: `<lhs>, <rhs> attr-dict : (<lhs type>, <rhs type>) -> <output type>`.
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
}];
}
def SpatVAddOp : SpatOp<"vadd", []> {
let summary = "Element-wise addition between two tensors; rhs must match lhs or be 1x1";
@@ -249,6 +249,32 @@ LogicalResult SpatVMMOp::verify() {
return success();
}
/// Verifies the operand/result types of spat.vvdmul.
///
/// Checks, in this order (the first violated rule produces the diagnostic):
///   1. lhs, rhs and output are shaped types;
///   2. all three are ranked;
///   3. all three have rank 2;
///   4. all three share the same element type;
///   5. lhs and rhs have identical shapes;
///   6. that shape is (1, N) with N > 0;
///   7. the output shape is (1, 1).
LogicalResult SpatVVDMulOp::verify() {
  auto lhsTy = dyn_cast<ShapedType>(getLhs().getType());
  auto rhsTy = dyn_cast<ShapedType>(getRhs().getType());
  auto resultTy = dyn_cast<ShapedType>(getOutput().getType());

  const bool allShaped = lhsTy && rhsTy && resultTy;
  if (!allShaped)
    return emitError("lhs, rhs, and output must be shaped values");

  const bool allRanked = lhsTy.hasRank() && rhsTy.hasRank() && resultTy.hasRank();
  if (!allRanked)
    return emitError("lhs, rhs, and output must have ranked types");

  ArrayRef<int64_t> lhsDims = lhsTy.getShape();
  ArrayRef<int64_t> rhsDims = rhsTy.getShape();
  ArrayRef<int64_t> resultDims = resultTy.getShape();

  const bool allRankTwo = lhsDims.size() == 2 && rhsDims.size() == 2 && resultDims.size() == 2;
  if (!allRankTwo)
    return emitError("lhs, rhs, and output must have rank 2");

  auto elementTy = lhsTy.getElementType();
  const bool sameElementType = elementTy == rhsTy.getElementType() && elementTy == resultTy.getElementType();
  if (!sameElementType)
    return emitError("lhs, rhs, and output must have the same element type");

  if (lhsDims != rhsDims)
    return emitError("lhs and rhs vector shapes must match");

  // Shapes already match, so checking lhs covers rhs as well.
  const bool isRowVector = lhsDims[0] == 1 && lhsDims[1] > 0;
  if (!isRowVector)
    return emitError("lhs and rhs vector shape must be (1, N) with N > 0");

  const bool isScalarResult = resultDims[0] == 1 && resultDims[1] == 1;
  if (!isScalarResult)
    return emitError("output shape must be (1, 1)");

  return success();
}
LogicalResult SpatVAddOp::verify() {
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
return failure();
@@ -1019,27 +1019,6 @@ std::optional<IndexedIndexPattern> getIndexedIndexPattern(ArrayRef<int64_t> valu
return std::nullopt;
}
/// Evaluates `map` on `operands`, preferring a constant over an affine.apply.
///
/// If every operand has a constant integer value and `map` constant-folds to
/// an IntegerAttr, the result is produced by createIndexConstant(state,
/// anchor, value). In every other case an affine::AffineApplyOp is created at
/// `loc` with state.rewriter and its result is returned.
Value createAffineApplyOrConstant(
    MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) {
  // Shared fallback for all the "cannot fold" exits below.
  auto emitApply = [&]() -> Value {
    return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult();
  };

  // Gather operand constants; a single non-constant operand rules folding out.
  SmallVector<Attribute> constantOperands;
  constantOperands.reserve(operands.size());
  for (Value operand : operands) {
    auto maybeConstant = getConstantIntValue(operand);
    if (!maybeConstant)
      return emitApply();
    constantOperands.push_back(state.rewriter.getIndexAttr(*maybeConstant));
  }

  SmallVector<Attribute> folded;
  const bool didFold = succeeded(map.constantFold(constantOperands, folded));
  if (didFold) {
    // Only the first folded result is inspected.
    if (auto foldedInt = dyn_cast<IntegerAttr>(folded.front()))
      return createIndexConstant(state, anchor, foldedInt.getInt());
  }

  return emitApply();
}
Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
MLIRContext* context = state.func.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
@@ -1054,7 +1033,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern
}
AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func);
}
Value createIndexedIndexValue(
@@ -3396,7 +3375,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe
int64_t laneCount = static_cast<int64_t>(targetClass.cpus.size());
AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1);
return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func);
}
Value createBatchClassRunSourceLane(MaterializerState& state,