Files
Raptor/src/PIM/Dialect/Pim/PimOpsVerify.cpp
T
2026-05-22 15:23:48 +02:00

304 lines
12 KiB
C++

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace pim {
namespace {
/// Returns true if operand `operandIndex` of `op` is the one operand position
/// that is exempt from the "external constants only" rule for the explicit
/// memory-copy ops below. Every other op/operand combination returns false.
/// NOTE(review): the indices are presumably the host-side buffer operand of
/// each copy op — TODO confirm against the op definitions in PimOps.
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  int64_t exemptIndex = -1;
  if (isa<PimMemCopyHostToDevOp>(op))
    exemptIndex = 3;
  else if (isa<PimMemCopyHostToDevBatchOp>(op))
    exemptIndex = 1;
  else if (isa<PimMemCopyDevToHostOp>(op))
    exemptIndex = 2;
  // -1 means "not a copy op": no operand position is exempt.
  return exemptIndex >= 0 && static_cast<int64_t>(operandIndex) == exemptIndex;
}
/// Returns the region that owns `value`:
///   - for a block argument, the region of its owning block;
///   - for an op result, the region that contains the defining op.
/// Returns nullptr when no such region exists (e.g. a defining op that is not
/// nested in any region).
static Region* getParentRegion(Value value) {
  auto asArgument = dyn_cast<BlockArgument>(value);
  if (asArgument)
    return asArgument.getParentRegion();
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return nullptr;
  return producer->getParentRegion();
}
/// Returns true if `value` is defined directly in `region` or in any region
/// nested inside it. Values whose owning region cannot be determined are
/// treated as defined outside.
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* owner = getParentRegion(value);
  if (owner == nullptr)
    return false;
  if (owner == &region)
    return true;
  return region.isAncestor(owner);
}
/// Returns true if `value` is known to be constant. That is the case when it is
/// produced either by an op with the ConstantLike trait, or by a
/// memref.get_global whose referenced global is marked constant.
/// Block arguments, and results of any other op, are treated as non-constant.
static bool isConstantExternalValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;
  if (definingOp->hasTrait<OpTrait::ConstantLike>())
    return true;
  auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
  if (!getGlobalOp)
    return false;
  // The referenced global is resolved through the enclosing module. An op that
  // is not (yet) nested in a module has nothing to resolve against, so do not
  // hand a null ModuleOp to the lookup helper; conservatively report
  // "not constant" instead.
  auto moduleOp = definingOp->getParentOfType<ModuleOp>();
  if (!moduleOp)
    return false;
  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  return globalOp && globalOp.getConstant();
}
/// Checks that every operand of every op nested in `region` satisfies one of:
///   - it is defined inside `region` (including nested regions);
///   - it is constant per isConstantExternalValue;
///   - its operand position is accepted by isExplicitHostOperand.
/// Each offending operand produces its own error on `ownerOp`, with a note at
/// the using op. Returns failure if at least one error was emitted.
/// `kind` is the op name used as the diagnostic prefix.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  auto isAllowedUse = [&](Operation* user, OpOperand& use) {
    Value used = use.get();
    return isDefinedInsideRegion(used, region) || isConstantExternalValue(used)
           || isExplicitHostOperand(user, use.getOperandNumber());
  };

  unsigned violationCount = 0;
  region.walk([&](Operation* user) {
    for (OpOperand& use : user->getOpOperands()) {
      if (isAllowedUse(user, use))
        continue;
      ++violationCount;
      InFlightDiagnostic report = ownerOp->emitOpError()
                                  << kind << " body may only directly reference external constants";
      report.attachNote(user->getLoc())
          << "non-constant external operand #" << use.getOperandNumber() << " is used by " << user->getName();
    }
  });
  return violationCount == 0 ? success() : failure();
}
/// Returns true if `lhs` and `rhs` are both ranked tensors or both memrefs.
/// Any other pairing returns false, including unranked tensors and mixed
/// tensor/memref pairs.
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
  const bool bothRankedTensors = isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs);
  if (bothRankedTensors)
    return true;
  const bool bothMemRefs = isa<MemRefType>(lhs) && isa<MemRefType>(rhs);
  return bothMemRefs;
}
/// Succeeds only if `lhs` and `rhs` are both shaped types of the same container
/// kind (see haveSameShapedContainerKind) with identical element type and
/// identical shape. Otherwise it emits `message` on `op` and fails.
/// Only element type and shape are compared; other type parameters (e.g. memref
/// layout or memory space) are not inspected here.
static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) {
  auto lhsShaped = dyn_cast<ShapedType>(lhs);
  auto rhsShaped = dyn_cast<ShapedType>(rhs);
  // Short-circuit order matters: the container-kind check guarantees both are
  // ranked before getShape() is queried.
  const bool compatible = lhsShaped && rhsShaped && haveSameShapedContainerKind(lhs, rhs)
                          && lhsShaped.getElementType() == rhsShaped.getElementType()
                          && lhsShaped.getShape() == rhsShaped.getShape();
  if (compatible)
    return success();
  return op->emitError(message);
}
/// Verifies a tensor transfer that is split into one chunk per entry of
/// `coreIds`. Requires all of the following:
///   - at least one core id;
///   - `type` is a statically shaped tensor/memref;
///   - the element type is an integer or float of a whole number of bytes;
///   - the total byte size divides evenly by the number of core ids.
/// `kind` (e.g. "send_tensor") prefixes every diagnostic.
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";
  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";
  // getElementTypeBitWidth() asserts for element types other than int/float
  // (index, complex, ...). A verifier must diagnose, not crash, so reject those
  // element types up front.
  if (!shapedType.getElementType().isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";
  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";
  int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
  if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
  return success();
}
/// Batched variant of the tensor-transfer check, for ops nested in
/// pim.core_batch. `coreIds` holds laneCount * chunkCount entries. Requires:
///   - at least one core id;
///   - an enclosing pim.core_batch with a positive laneCount that divides the
///     number of core ids;
///   - `type` is a statically shaped tensor/memref whose element type is an
///     integer or float of a whole number of bytes;
///   - the total byte size divides evenly by the per-lane chunk count.
/// `kind` (e.g. "send_tensor_batch") prefixes every diagnostic.
static LogicalResult
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";
  auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
  if (!coreBatchOp)
    return op->emitError() << kind << " must be nested inside pim.core_batch";
  int32_t laneCount = coreBatchOp.getLaneCount();
  if (laneCount <= 0)
    return op->emitError() << kind << " requires a positive parent laneCount";
  if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
    return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";
  // getElementTypeBitWidth() asserts for element types other than int/float
  // (index, complex, ...). Diagnose those instead of crashing the verifier.
  if (!shapedType.getElementType().isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";
  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";
  // coreIds is non-empty and divisible by laneCount, so chunkCount >= 1.
  int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
  int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
  if (totalBytes % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
  return success();
}
/// Returns the shape of `weight`. Fails if its type is not a ranked shaped
/// type. Unranked types must be rejected here because ShapedType::getShape()
/// asserts on them; the result may still contain dynamic extents.
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
  auto shapedType = dyn_cast<ShapedType>(weight.getType());
  if (!shapedType || !shapedType.hasRank())
    return failure();
  return shapedType.getShape();
}
} // namespace
/// Verifies pim.core:
///   - the body's entry block has exactly one argument per weight operand;
///   - each such argument has exactly the type of its weight operand;
///   - the body references only values allowed by
///     verifyOnlyConstantExternalValues.
LogicalResult PimCoreOp::verify() {
  auto weights = getWeights();
  if (getBody().front().getNumArguments() != weights.size())
    return emitError("core body must have one block argument per weight");
  for (size_t index = 0, count = weights.size(); index < count; ++index) {
    if (getWeightArgument(index).getType() != weights[index].getType())
      return emitError("core weight block argument types must match weight operand types exactly");
  }
  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
/// Verifies pim.core_batch:
///   - laneCount is positive;
///   - the entry block has 1 + #weights + #inputs arguments, with the lane
///     argument of index type;
///   - each weight/input block argument has exactly the type of its operand;
///   - the body references only values allowed by
///     verifyOnlyConstantExternalValues.
LogicalResult PimCoreBatchOp::verify() {
  if (getLaneCount() <= 0)
    return emitError("laneCount must be positive");

  auto weights = getWeights();
  auto inputs = getInputs();
  // One lane argument plus one argument per weight and per input.
  unsigned expectedArgCount = 1 + weights.size() + inputs.size();
  if (getBody().front().getNumArguments() != expectedArgCount)
    return emitError("core_batch body must have lane, weight, and input block arguments");
  if (!getLaneArgument().getType().isIndex())
    return emitError("core_batch first block argument must have index type");

  for (size_t index = 0, count = weights.size(); index < count; ++index) {
    if (getWeightArgument(index).getType() != weights[index].getType())
      return emitError("core_batch weight block argument types must match weight operand types exactly");
  }
  for (size_t index = 0, count = inputs.size(); index < count; ++index) {
    if (getInputArgument(index).getType() != inputs[index].getType())
      return emitError("core_batch input block argument types must match input operand types exactly");
  }
  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
/// Verifies pim.send_tensor by delegating to verifyTensorCommunication on the
/// input type and the target core ids.
LogicalResult PimSendTensorOp::verify() {
  Type payloadType = getInput().getType();
  auto destinations = getTargetCoreIds();
  return verifyTensorCommunication(getOperation(), payloadType, destinations, "send_tensor");
}
/// Verifies pim.send_tensor_batch by delegating to
/// verifyTensorBatchCommunication on the input type and the target core ids.
LogicalResult PimSendTensorBatchOp::verify() {
  Type payloadType = getInput().getType();
  auto destinations = getTargetCoreIds();
  return verifyTensorBatchCommunication(getOperation(), payloadType, destinations, "send_tensor_batch");
}
/// Verifies pim.receive_tensor:
///   - the output buffer type matches the output type;
///   - the output passes verifyTensorCommunication for the source core ids.
LogicalResult PimReceiveTensorOp::verify() {
  Operation* self = getOperation();
  Type resultType = getOutput().getType();
  if (failed(verifyCompatibleShapedTypes(
          self, getOutputBuffer().getType(), resultType, "output buffer and output must match")))
    return failure();
  return verifyTensorCommunication(self, resultType, getSourceCoreIds(), "receive_tensor");
}
/// Verifies pim.receive_tensor_batch:
///   - the output buffer type matches the output type;
///   - the output passes verifyTensorBatchCommunication for the source core ids.
LogicalResult PimReceiveTensorBatchOp::verify() {
  Operation* self = getOperation();
  Type resultType = getOutput().getType();
  if (failed(verifyCompatibleShapedTypes(
          self, getOutputBuffer().getType(), resultType, "output buffer and output must match")))
    return failure();
  return verifyTensorBatchCommunication(self, resultType, getSourceCoreIds(), "receive_tensor_batch");
}
/// Verifies pim.vmm (vector x matrix on a single crossbar). Checks, in order:
///   - the output buffer and output have identical container kind, element type
///     and shape;
///   - the weight is a shaped value;
///   - input and output are shaped types;
///   - weight, input and output all have rank 2; unranked input/output are
///     rejected here, before ShapedType::getShape() (which asserts on unranked
///     types) is called;
///   - the weight has shape (N, M) with 0 < N, M <= crossbarSize (a dynamic
///     extent is negative and so fails the > 0 check);
///   - input and output have shape (1, crossbarSize).
LogicalResult PimVMMOp::verify() {
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();
  auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
  if (failed(matrixShapeOpt))
    return emitError("weight must be a shaped value");
  ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
  auto vectorType = dyn_cast<ShapedType>(getInput().getType());
  auto outputType = dyn_cast<ShapedType>(getOutput().getType());
  if (!vectorType || !outputType)
    return emitError("input and output must be shaped types");
  // Rank must be established before getShape(): getShape() asserts on unranked
  // types such as tensor<*xf32>.
  if (matrixShape.size() != 2 || !vectorType.hasRank() || vectorType.getRank() != 2 || !outputType.hasRank()
      || outputType.getRank() != 2)
    return emitError("matrix, vector and output must have rank 2");
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
  if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
    return emitError("matrix dimensions must fit in one crossbar");
  int64_t vector1 = vectorShape[0];
  int64_t vectorWidth = vectorShape[1];
  if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
    return emitError("vector shape must be (1, crossbar-size)");
  int64_t output1 = outputShape[0];
  int64_t outputWidth = outputShape[1];
  if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
    return emitError("output shape must be (1, crossbar-size)");
  return success();
}
/// Verifies pim.concat:
///   - at least one input;
///   - the output buffer matches the output (container kind, element type,
///     shape);
///   - the output is ranked and the axis lies in [0, rank);
///   - every input is a ranked shaped type of the output's container kind, rank
///     and element type;
///   - every non-axis extent agrees with the output wherever both are static;
///   - if all input axis extents and the output axis extent are static, the
///     output axis extent equals the sum of the input axis extents.
LogicalResult PimConcatOp::verify() {
  auto inputs = getInputs();
  if (inputs.empty())
    return emitError("requires at least one input");

  Type resultType = getOutput().getType();
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), resultType, "output buffer and output must match")))
    return failure();

  auto resultShaped = dyn_cast<ShapedType>(resultType);
  if (!resultShaped || !resultShaped.hasRank())
    return emitError("output must be a ranked shaped type");
  const int64_t axis = getAxis();
  const int64_t rank = resultShaped.getRank();
  if (axis < 0 || axis >= rank)
    return emitError("axis must be within the output rank");

  // Two extents conflict only if both are static and differ; a dynamic extent
  // is compatible with anything.
  auto extentsConflict = [](int64_t lhs, int64_t rhs) {
    return !ShapedType::isDynamic(lhs) && !ShapedType::isDynamic(rhs) && lhs != rhs;
  };

  bool summedExtentKnown = true;
  int64_t summedExtent = 0;
  for (Value operand : inputs) {
    Type operandType = operand.getType();
    auto operandShaped = dyn_cast<ShapedType>(operandType);
    if (!operandShaped || !operandShaped.hasRank())
      return emitError("inputs must be ranked shaped types");
    if (!haveSameShapedContainerKind(operandType, resultType))
      return emitError("inputs and output must use the same shaped container kind");
    if (operandShaped.getRank() != rank)
      return emitError("all inputs must have the same rank as the output");
    if (operandShaped.getElementType() != resultShaped.getElementType())
      return emitError("all inputs must have the same element type as the output");

    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim != axis && extentsConflict(operandShaped.getDimSize(dim), resultShaped.getDimSize(dim)))
        return emitError("non-concatenated dimensions must match the output shape");
    }

    const int64_t axisExtent = operandShaped.getDimSize(axis);
    if (ShapedType::isDynamic(axisExtent))
      summedExtentKnown = false;
    else
      summedExtent += axisExtent;
  }

  const int64_t resultAxisExtent = resultShaped.getDimSize(axis);
  if (summedExtentKnown && !ShapedType::isDynamic(resultAxisExtent) && summedExtent != resultAxisExtent)
    return emitError("output concatenated dimension must equal the sum of input sizes");
  return success();
}
} // namespace pim
} // namespace onnx_mlir