add pim.vmm verifier and fix vmm lowering

reuse code for subviews
This commit is contained in:
NiccoloN
2026-05-12 15:13:50 +02:00
parent 628dc630a4
commit 4f3570520c
15 changed files with 358 additions and 207 deletions
+58
View File
@@ -4,6 +4,7 @@
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
if (weightIndex >= coreOp.getWeights().size())
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
}
} // namespace
LogicalResult PimSendTensorOp::verify() {
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}
LogicalResult PimVMMOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!vectorType || !outputType)
return emitError("input and output must be shaped types");
ArrayRef<int64_t> vectorShape = vectorType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
return emitError("matrix dimensions must fit in one crossbar");
int64_t vector1 = vectorShape[0];
int64_t vectorWidth = vectorShape[1];
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
return emitError("vector shape must be (1, crossbar-size)");
int64_t output1 = outputShape[0];
int64_t outputWidth = outputShape[1];
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
return emitError("output shape must be (1, crossbar-size)");
return success();
}
LogicalResult PimConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");