refactorone
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-20 19:06:41 +02:00
parent f56c4159b5
commit a50e77ff38
50 changed files with 3420 additions and 1187 deletions
+231 -116
View File
@@ -1,6 +1,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
@@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return shapedType.getShape();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
@@ -105,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return batchOp.getLaneCount();
}
static LogicalResult verifyTensorChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
return false;
unsigned argNumber = blockArg.getArgNumber();
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
}
static bool isConstantIndexLike(Value value) {
APInt constantValue;
return matchPattern(value, m_ConstantInt(&constantValue));
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value))
return true;
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
return false;
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
}
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
BlockArgument laneArg,
StringRef kind) {
RankedTensorType sourceType = sliceOp.getSourceType();
RankedTensorType destType = sliceOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyTensorChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.empty())
if (channelCount == 0)
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
@@ -125,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
static LogicalResult verifyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount))
if (channelCount != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static LogicalResult verifyTensorBatchChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
@@ -169,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
@@ -177,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return op->emitError("body must terminate with spat.yield");
if (outputTypes.empty()) {
static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
<< " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results");
return batchOp.emitError("resultless compute_batch body yield must be empty");
}
else {
if (yieldOp.getNumOperands() != 1)
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
}
BlockArgument laneArg = batchOp.getLaneArgument();
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
return failure();
}
return success();
}
@@ -206,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} // namespace
LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -221,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
}
LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -347,13 +437,26 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute" );
return op->emitError("ComputeResult used directly inside another Compute");
}
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -386,9 +489,11 @@ LogicalResult SpatCompute::verify() {
}
}
for (auto arg : block.getArguments())
if (arg.use_empty())
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (getInputArgument(inputIndex).use_empty())
return emitError("ComputeOp block argument is not used");
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure();
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return success();
@@ -397,44 +502,46 @@ LogicalResult SpatCompute::verify() {
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor_batch");
}
@@ -444,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
@@ -482,27 +560,64 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be distinct");
return emitError("compute_batch coreIds values must be unique");
}
Block& block = getBody().front();
if (getInputs().empty()) {
if (block.getNumArguments() != 0)
return emitError("compute_batch body must have no block arguments when there are no inputs");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
else {
if (block.getNumArguments() != 1)
return emitError("compute_batch body must have exactly one block argument");
if (block.getArgument(0).getType() != getInputs()[0].getType())
return emitError("body block argument type must match input type");
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly");
}
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
BlockArgument blockArg = getOutputArgument(resultIndex);
if (blockArg.getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
}
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
return failure();
return verifyBatchBody(*this, block);
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
BlockArgument laneArg = batchOp.getLaneArgument();
for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
return success();
}
} // namespace spatial