huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -46,12 +46,47 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreIds, StringRef kind) {
|
||||
if (coreIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
return op->emitError() << kind << " must be nested inside pim.core_batch";
|
||||
|
||||
int32_t laneCount = coreBatchOp.getLaneCount();
|
||||
if (laneCount <= 0)
|
||||
return op->emitError() << kind << " requires a positive parent laneCount";
|
||||
if (coreIds.size() % static_cast<size_t>(laneCount) != 0)
|
||||
return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor or memref";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(coreIds.size()) / laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
@@ -60,6 +95,15 @@ LogicalResult PimReceiveTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
return verifyTensorBatchCommunication(
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
Reference in New Issue
Block a user