add pim.vmm verifier and fix vmm lowering
reuse code for subviews
This commit is contained in:
@@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimVMMOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!vectorType || !outputType)
|
||||
return emitError("input and output must be shaped types");
|
||||
|
||||
ArrayRef<int64_t> vectorShape = vectorType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
|
||||
return emitError("matrix dimensions must fit in one crossbar");
|
||||
|
||||
int64_t vector1 = vectorShape[0];
|
||||
int64_t vectorWidth = vectorShape[1];
|
||||
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("vector shape must be (1, crossbar-size)");
|
||||
|
||||
int64_t output1 = outputShape[0];
|
||||
int64_t outputWidth = outputShape[1];
|
||||
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("output shape must be (1, crossbar-size)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
@@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInterface, PimMemCopyOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyOp = cast<PimMemCopyOp>(op);
|
||||
|
||||
auto targetOpt = getBufferOrValue(rewriter, memCopyOp.getTarget(), options, state);
|
||||
if (failed(targetOpt))
|
||||
return failure();
|
||||
|
||||
auto sourceOpt = getBufferOrValue(rewriter, memCopyOp.getSource(), options, state);
|
||||
if (failed(sourceOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
||||
memCopyOp,
|
||||
targetOpt->getType(),
|
||||
*targetOpt,
|
||||
*sourceOpt,
|
||||
memCopyOp.getTargetOffsetAttr(),
|
||||
memCopyOp.getSourceOffsetAttr(),
|
||||
memCopyOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
|
||||
|
||||
@@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
@@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user