better spatial IR compaction with better custom syntax, scf.for and spat.map
@@ -83,13 +83,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
 }

 static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
-  if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
+  if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
     return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();

-  if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
+  if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
     return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();

-  if (auto batchOp = dyn_cast<SpatComputeBatch>(weightedOp->getParentOp())) {
+  if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
     if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
       return failure();
     return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
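The mechanical change in this first hunk swaps dyn_cast on the immediate parent for getParentOfType, which searches through all enclosing ops. Presumably this is what lets weighted ops sit inside the scf.for loops the commit title mentions; a minimal sketch of the difference (illustrative, not part of the diff):

    // getParentOp() returns only the immediate enclosing operation, so the
    // old lookup failed as soon as a weighted op was wrapped in another region:
    Operation* parent = weightedOp->getParentOp();  // e.g. an scf::ForOp
    auto direct = dyn_cast<SpatCompute>(parent);    // null inside a loop

    // getParentOfType walks upward through every ancestor until it matches:
    auto enclosing = weightedOp->getParentOfType<SpatCompute>();  // still found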
@@ -144,6 +144,23 @@ static LogicalResult verifyBatchChannelSizes(Operation* op,
   return success();
 }

+static LogicalResult verifyManyBatchChannelSizes(Operation* op,
+    ArrayRef<int64_t> channelIds,
+    ArrayRef<int32_t> sourceCoreIds,
+    ArrayRef<int32_t> targetCoreIds,
+    size_t valueCount) {
+  if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
+    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
+
+  auto laneCount = getParentBatchLaneCount(op);
+  if (failed(laneCount))
+    return op->emitError("must be nested inside spat.compute_batch");
+  if (channelIds.size() != valueCount * static_cast<size_t>(*laneCount))
+    return op->emitError("channel metadata length must match the number of values times parent laneCount");
+
+  return success();
+}
+
 static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
   auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
   if (!yieldOp)
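The new helper relies on getParentBatchLaneCount, whose definition sits outside this hunk. A plausible reconstruction, inferred only from its use here and from SpatComputeBatch::getLaneCount() later in the diff (hypothetical, not the committed code):

    // Hypothetical sketch; the real helper is defined elsewhere in the file.
    static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
      if (auto batchOp = op->getParentOfType<SpatComputeBatch>())
        return batchOp.getLaneCount();
      return failure();  // not nested inside spat.compute_batch
    }

Under that reading, the final size check encodes one channel-id entry per value per lane: an op carrying 3 values under laneCount = 4 must supply exactly 12 entries in each metadata array.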
@@ -306,6 +323,39 @@ LogicalResult SpatConcatOp::verify() {
   return success();
 }

+LogicalResult SpatMapOp::verify() {
+  if (getInputs().empty())
+    return emitError("requires at least one input");
+  if (getOutputs().size() != getInputs().size())
+    return emitError("number of outputs must match number of inputs");
+
+  Type inputType = getInputs().front().getType();
+  for (Value input : getInputs().drop_front())
+    if (input.getType() != inputType)
+      return emitError("all inputs must have the same type");
+
+  Type outputType = getOutputs().front().getType();
+  for (Value output : getOutputs().drop_front())
+    if (output.getType() != outputType)
+      return emitError("all outputs must have the same type");
+
+  Block& block = getBody().front();
+  if (block.getNumArguments() != 1)
+    return emitError("body must have exactly one block argument");
+  if (block.getArgument(0).getType() != inputType)
+    return emitError("body block argument type must match input type");
+
+  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
+  if (!yieldOp)
+    return emitError("body must terminate with spat.yield");
+  if (yieldOp.getNumOperands() != 1)
+    return emitError("body yield must produce exactly one value");
+  if (yieldOp.getOperand(0).getType() != outputType)
+    return emitError("body yield type must match output type");
+
+  return success();
+}
+
 LogicalResult SpatCompute::verify() {
   auto& block = getBody().front();
   if (block.mightHaveTerminator()) {
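The two uniformity loops in SpatMapOp::verify() could equally be phrased with llvm::all_of; a sketch of that alternative formulation (same check, not what the commit ships):

    // Input-uniformity check rewritten with llvm::all_of (sketch only):
    Type inputType = getInputs().front().getType();
    if (!llvm::all_of(getInputs().getTypes(),
            [&](Type type) { return type == inputType; }))
      return emitError("all inputs must have the same type");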
@@ -365,10 +415,24 @@ LogicalResult SpatChannelSendBatchOp::verify() {
   return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
 }

+LogicalResult SpatChannelSendManyBatchOp::verify() {
+  if (failed(verifyManyBatchChannelSizes(
+          getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size())))
+    return failure();
+  return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch");
+}
+
 LogicalResult SpatChannelReceiveBatchOp::verify() {
   return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
 }

+LogicalResult SpatChannelReceiveManyBatchOp::verify() {
+  if (failed(verifyManyBatchChannelSizes(
+          getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size())))
+    return failure();
+  return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch");
+}
+
 LogicalResult SpatComputeBatch::verify() {
   int32_t count = getLaneCount();
   if (count <= 0)
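Both ManyBatch verifiers finish by delegating to verifyManyChannelTypes, which the diff does not show. Going only by its name and these call sites, a hypothetical sketch of what it likely checks:

    // Hypothetical; the real verifyManyChannelTypes is outside this diff.
    static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types,
        StringRef opName) {
      if (types.empty())
        return op->emitError() << opName << " requires at least one value";
      Type first = types.front();
      for (Type type : types.drop_front())
        if (type != first)
          return op->emitError() << opName << " values must share one type";
      return success();
    }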
@@ -405,18 +469,18 @@ LogicalResult SpatComputeBatch::verify() {
     return emitError("all outputs must have the same type");
   }

-  if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) {
+  if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) {
     auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
     if (!coreIdsAttr)
-      return emitError("compute_batch core_id attribute must be a dense i32 array");
+      return emitError("compute_batch coreIds attribute must be a dense i32 array");
     if (coreIdsAttr.size() != laneCountSz)
-      return emitError("compute_batch core_id array length must match laneCount");
+      return emitError("compute_batch coreIds array length must match laneCount");
     if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
-      return emitError("compute_batch core_id values must be positive");
+      return emitError("compute_batch coreIds values must be positive");
     llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
     for (int32_t coreId : coreIdsAttr.asArrayRef())
       if (!seenCoreIds.insert(coreId).second)
-        return emitError("compute_batch core_id values must be distinct");
+        return emitError("compute_batch coreIds values must be distinct");
   }

   Block& block = getBody().front();

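For producers of spat.compute_batch, the renamed attribute must now be a dense i32 array of positive, distinct core ids whose length equals laneCount. A minimal sketch of attaching a conforming attribute (a laneCount of 4 is assumed purely for illustration):

    // Illustrative only: attach a coreIds attribute that passes the new
    // checks, given a SpatComputeBatch `batchOp` with laneCount == 4.
    batchOp->setAttr(onnx_mlir::kCoreIdsAttrName,
        DenseI32ArrayAttr::get(batchOp.getContext(), {1, 2, 3, 4}));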