huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
This commit is contained in:
@@ -105,26 +105,28 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.size() != valueCount)
|
||||
return op->emitError("channel metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
if (channelIds.empty())
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
Type firstType = types.front();
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
size_t valueCount) {
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match the number of values times parent laneCount");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
return op->emitError() << kind << " requires a static shaped tensor";
|
||||
|
||||
int64_t elementBits = shapedType.getElementTypeBitWidth();
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatMapOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
if (getOutputs().size() != getInputs().size())
|
||||
return emitError("number of outputs must match number of inputs");
|
||||
|
||||
Type inputType = getInputs().front().getType();
|
||||
for (Value input : getInputs().drop_front())
|
||||
if (input.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
|
||||
Type outputType = getOutputs().front().getType();
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many");
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyOp::verify() {
|
||||
if (failed(verifyManyChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many");
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchChannelSizes(
|
||||
getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatComputeBatch::verify() {
|
||||
|
||||
Reference in New Issue
Block a user