fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s

This commit is contained in:
NiccoloN
2026-05-08 14:21:45 +02:00
parent 58e6587697
commit b1272d2283
7 changed files with 541 additions and 81 deletions
+38 -6
View File
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
return op->emitError() << kind << " values must all have the same type";
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
return op->emitError() << kind << " values must all use the same shaped container kind";
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
if (firstShapedType.getElementType() != shapedType.getElementType()
|| firstShapedType.getShape() != shapedType.getShape())
return op->emitError() << kind << " values must all have the same shape and element type";
}
return success();
}
/// Verifies the payload of a tensor-level communication operation.
///
/// Emits an error on `op`, prefixed with `kind`, and returns failure when:
///  - `coreIds` is empty;
///  - `type` is not a ShapedType with a static shape;
///  - the element type is not an integer/float type, or its bit width is not
///    a positive multiple of 8;
///  - the total byte size of the value is not divisible by `coreIds.size()`.
/// Returns success otherwise.
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";

  // getElementTypeBitWidth() is only defined for integer and float element
  // types and asserts otherwise (e.g. index or complex elements). A verifier
  // must diagnose invalid IR rather than abort, so reject those types first.
  if (!shapedType.getElementType().isIntOrFloat())
    return op->emitError() << kind << " requires byte-sized elements";

  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // elementBits is a multiple of 8 here, so dividing first is exact and keeps
  // the intermediate product smaller than multiplying by the bit width.
  int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";

  return success();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
if (!coreBatchOp)
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return coreBatchOp.getLaneCount();
}
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
ArrayRef<int32_t> coreIds,
size_t valueCount) {
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside pim.core_batch");
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != 1)
return emitError("body must have exactly one block argument");
if (block.getArgument(0).getType() != inputType)
if (failed(verifyCompatibleShapedTypes(
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
return emitError("body block argument type must match input type");
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
return emitError("body must terminate with pim.yield");
if (yieldOp.getNumOperands() != 1)
return emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputType)
if (failed(verifyCompatibleShapedTypes(
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
return emitError("body yield type must match output type");
return success();
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
}
/// Verifier for PimSendTensorOp: checks the input type against the list of
/// target core ids via verifyTensorCommunication, reporting errors under the
/// "send_tensor" label.
LogicalResult PimSendTensorOp::verify() {
  Operation* op = getOperation();
  Type inputType = getInput().getType();
  return verifyTensorCommunication(op, inputType, getTargetCoreIds(), "send_tensor");
}
LogicalResult PimSendManyBatchOp::verify() {
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
return failure();
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
return success();
}
/// Verifier for PimReceiveTensorOp.
///
/// First requires the output buffer type and the output type to pass
/// verifyCompatibleShapedTypes; if that check fails, failure is returned
/// without running further checks. Otherwise the output type is checked
/// against the source core ids via verifyTensorCommunication, reporting
/// errors under the "receive_tensor" label.
LogicalResult PimReceiveTensorOp::verify() {
  Operation* op = getOperation();
  Type outputType = getOutput().getType();

  if (succeeded(verifyCompatibleShapedTypes(
          op, getOutputBuffer().getType(), outputType, "output buffer and output must match")))
    return verifyTensorCommunication(op, outputType, getSourceCoreIds(), "receive_tensor");

  return failure();
}
LogicalResult PimReceiveManyBatchOp::verify() {
if (getOutputBuffers().size() != getOutputs().size())
return emitError("number of output buffers must match the number of outputs");