add pim.vmm verifier and fix vmm lowering

reuse code for subviews
This commit is contained in:
NiccoloN
2026-05-12 15:13:50 +02:00
parent 628dc630a4
commit 4f3570520c
15 changed files with 358 additions and 207 deletions
+1
View File
@@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
}
}];
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
+58
View File
@@ -4,6 +4,7 @@
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
if (weightIndex >= coreOp.getWeights().size())
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
}
} // namespace
LogicalResult PimSendTensorOp::verify() {
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}
LogicalResult PimVMMOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
if (!vectorType || !outputType)
return emitError("input and output must be shaped types");
ArrayRef<int64_t> vectorShape = vectorType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
return emitError("matrix, vector and output must have rank 2");
int64_t N = matrixShape[0];
int64_t M = matrixShape[1];
if (N <= 0 || M <= 0)
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
return emitError("matrix dimensions must fit in one crossbar");
int64_t vector1 = vectorShape[0];
int64_t vectorWidth = vectorShape[1];
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
return emitError("vector shape must be (1, crossbar-size)");
int64_t output1 = outputShape[0];
int64_t outputWidth = outputShape[1];
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
return emitError("output shape must be (1, crossbar-size)");
return success();
}
LogicalResult PimConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
@@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface
}
};
struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInterface, PimMemCopyOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyOp = cast<PimMemCopyOp>(op);
auto targetOpt = getBufferOrValue(rewriter, memCopyOp.getTarget(), options, state);
if (failed(targetOpt))
return failure();
auto sourceOpt = getBufferOrValue(rewriter, memCopyOp.getSource(), options, state);
if (failed(sourceOpt))
return failure();
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
memCopyOp,
targetOpt->getType(),
*targetOpt,
*sourceOpt,
memCopyOp.getTargetOffsetAttr(),
memCopyOp.getSourceOffsetAttr(),
memCopyOp.getSizeAttr());
return success();
}
};
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
@@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
+8 -6
View File
@@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
printer << " ";
printer.printOperand(op.getInput());
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
}
@@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
template <typename TensorReceiveOpTy>
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(
op->getAttrs(),
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getOutput().getType());
}