add pim.vmm verifier and fix vmm lowering
reuse code for subviews
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimVMMOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!vectorType || !outputType)
|
||||
return emitError("input and output must be shaped types");
|
||||
|
||||
ArrayRef<int64_t> vectorShape = vectorType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
|
||||
return emitError("matrix dimensions must fit in one crossbar");
|
||||
|
||||
int64_t vector1 = vectorShape[0];
|
||||
int64_t vectorWidth = vectorShape[1];
|
||||
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("vector shape must be (1, crossbar-size)");
|
||||
|
||||
int64_t output1 = outputShape[0];
|
||||
int64_t outputWidth = outputShape[1];
|
||||
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("output shape must be (1, crossbar-size)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
Reference in New Issue
Block a user