refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+86 -16
View File
@@ -1,5 +1,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h"
@@ -14,6 +16,52 @@ namespace pim {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static Region* getParentRegion(Value value) {
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return blockArgument.getParentRegion();
Operation* definingOp = value.getDefiningOp();
return definingOp ? definingOp->getParentRegion() : nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|| isExplicitHostOperand(op, operand.getOperandNumber()))
continue;
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
}
@@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
if (weightIndex >= coreOp.getWeights().size())
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return shapedType.getShape();
}
} // namespace
LogicalResult PimCoreOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
LogicalResult PimCoreBatchOp::verify() {
if (getLaneCount() <= 0)
return emitError("laneCount must be positive");
Block& block = getBody().front();
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
@@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() {
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
return emitError("weight must be a shaped value");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType());