big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-11 14:38:13 +02:00
parent b1272d2283
commit 5ff364027b
12 changed files with 390 additions and 1164 deletions
-180
View File
@@ -13,12 +13,6 @@ namespace pim {
namespace {
static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
if (coreIds.size() != valueCount)
return op->emitError("core id metadata length must match the number of values");
return success();
}
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
}
@@ -33,28 +27,6 @@ static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type r
return success();
}
static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) {
if (types.empty())
return op->emitError() << kind << " must carry at least one value";
Type firstType = types.front();
auto firstShapedType = dyn_cast<ShapedType>(firstType);
bool firstIsTensor = isa<RankedTensorType>(firstType);
bool firstIsMemRef = isa<MemRefType>(firstType);
for (Type type : types.drop_front())
if (type != firstType) {
auto shapedType = dyn_cast<ShapedType>(type);
if (!firstShapedType || !shapedType)
return op->emitError() << kind << " values must all have the same type";
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
return op->emitError() << kind << " values must all use the same shaped container kind";
if (firstShapedType.getElementType() != shapedType.getElementType()
|| firstShapedType.getShape() != shapedType.getShape())
return op->emitError() << kind << " values must all have the same shape and element type";
}
return success();
}
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
if (coreIds.empty())
return op->emitError() << kind << " must carry at least one chunk";
@@ -74,109 +46,12 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
return success();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
return failure();
return coreBatchOp.getLaneCount();
}
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside pim.core_batch");
if (coreIds.size() != valueCount * static_cast<size_t>(*laneCount))
return op->emitError("core id metadata length must match the number of values times parent laneCount");
return success();
}
} // namespace
LogicalResult PimEmptyManyOp::verify() {
if (getOutputs().empty())
return emitError("must produce at least one output");
Type firstType = getOutputs().front().getType();
auto firstShapedType = dyn_cast<ShapedType>(firstType);
if (!firstShapedType || !firstShapedType.hasRank())
return emitError("outputs must all be ranked shaped types");
for (Value output : getOutputs().drop_front())
if (output.getType() != firstType)
return emitError("outputs must all have the same type");
return success();
}
LogicalResult PimMapOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");
if (getOutputs().size() != getInputs().size())
return emitError("number of outputs must match number of inputs");
Type inputType = getInputs().front().getType();
for (Value input : getInputs().drop_front())
if (input.getType() != inputType)
return emitError("all inputs must have the same type");
Type outputType = getOutputs().front().getType();
for (Value output : getOutputs().drop_front())
if (output.getType() != outputType)
return emitError("all outputs must have the same type");
Block& block = getBody().front();
if (block.getNumArguments() != 1)
return emitError("body must have exactly one block argument");
if (failed(verifyCompatibleShapedTypes(
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
return emitError("body block argument type must match input type");
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
if (!yieldOp)
return emitError("body must terminate with pim.yield");
if (yieldOp.getNumOperands() != 1)
return emitError("body yield must produce exactly one value");
if (failed(verifyCompatibleShapedTypes(
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
return emitError("body yield type must match output type");
return success();
}
LogicalResult PimSendManyOp::verify() {
if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendManyBatchOp::verify() {
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch");
}
LogicalResult PimReceiveManyOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many")))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many")))
return failure();
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
if (outputBuffer.getType() != output.getType())
return emitError("output buffers and outputs must have matching types");
return success();
}
LogicalResult PimReceiveTensorOp::verify() {
if (failed(verifyCompatibleShapedTypes(
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
@@ -185,61 +60,6 @@ LogicalResult PimReceiveTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}
LogicalResult PimReceiveManyBatchOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch")))
return failure();
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch")))
return failure();
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
if (outputBuffer.getType() != output.getType())
return emitError("output buffers and outputs must have matching types");
return success();
}
LogicalResult PimExtractRowsOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");
auto inputType = dyn_cast<ShapedType>(getInput().getType());
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
return emitError("input must be a rank-2 shaped type");
int64_t numRows = inputType.getShape()[0];
int64_t numCols = inputType.getShape()[1];
Type elementType = inputType.getElementType();
if (numRows >= 0 && static_cast<int64_t>(getOutputs().size()) != numRows)
return emitError("number of outputs must match the number of input rows");
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) {
if (failed(verifyCompatibleShapedTypes(
getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match")))
return failure();
auto outputType = dyn_cast<ShapedType>(output.getType());
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
return emitError("outputs must all be rank-2 shaped types");
if (!haveSameShapedContainerKind(getInput().getType(), output.getType()))
return emitError("outputs must use the same shaped container kind as the input");
if (outputType.getElementType() != elementType)
return emitError("output element types must match input element type");
auto outputShape = outputType.getShape();
if (outputShape[0] != 1)
return emitError("each output must have exactly one row");
if (numCols >= 0 && outputShape[1] != numCols)
return emitError("output column count must match input column count");
}
return success();
}
LogicalResult PimConcatOp::verify() {
if (getInputs().empty())
return emitError("requires at least one input");