fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s
Validate Operations / validate-operations (push) Successful in 24m29s
This commit is contained in:
@@ -48,12 +48,32 @@ static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
||||
return op->emitError() << kind << " values must all use the same shaped container kind";
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType()
|
||||
|| firstShapedType.getShape() != shapedType.getShape())
|
||||
return op->emitError() << kind << " values must all have the same shape and element type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(coreIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
@@ -61,9 +81,7 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return coreBatchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
|
||||
ArrayRef<int32_t> coreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside pim.core_batch");
|
||||
@@ -109,7 +127,8 @@ LogicalResult PimMapOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), block.getArgument(0).getType(), inputType, "body block argument type must match input type")))
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
||||
@@ -117,7 +136,8 @@ LogicalResult PimMapOp::verify() {
|
||||
return emitError("body must terminate with pim.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), yieldOp.getOperand(0).getType(), outputType, "body yield type must match output type")))
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
@@ -129,6 +149,10 @@ LogicalResult PimSendManyOp::verify() {
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
@@ -153,6 +177,14 @@ LogicalResult PimReceiveManyOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyBatchOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
|
||||
Reference in New Issue
Block a user