#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

namespace {

//===----------------------------------------------------------------------===//
// spat.mvm shape helpers
//===----------------------------------------------------------------------===//

/// Verifies the rank-2 layout of spat.mvm:
///   matrix (N, M) with N > 0 and M > 0, vector (M, 1), output (N, 1).
static LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
    ArrayRef<int64_t> matrixShape,
    ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitter->emitError("matrix, vector and output must have rank 2");

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  if (vectorShape[0] != M || vectorShape[1] != 1)
    return emitter->emitError("vector shape must be (M, 1)");

  if (outputShape[0] != N || outputShape[1] != 1)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}

/// Verifies the rank-4 layout of spat.mvm:
///   matrix (N, M, 1, 1) with N > 0 and M > 0, vector (1, M, 1, 1),
///   output (1, N, 1, 1).
/// When `ignoreConcatError` is set, a vector whose M dimension differs from
/// the matrix is tolerated, provided its other dimensions are all 1.
static LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
    ArrayRef<int64_t> matrixShape,
    ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {
  if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
    return emitter->emitError("matrix, vector and output must have rank 4");

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0 || matrixShape[2] != 1 || matrixShape[3] != 1)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  // `ignoreConcatError` is not defined in this file (presumably a PIM compiler
  // option). It only relaxes the M dimension; a mismatch there can result from
  // the concat simplification.
  bool vectorHasUnitDims = vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1;
  if (!vectorHasUnitDims || (vectorShape[1] != M && !ignoreConcatError))
    return emitter->emitError("vector shape must be (1, M, 1, 1)");

  if (outputShape[0] != 1 || outputShape[1] != N || outputShape[2] != 1 || outputShape[3] != 1)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}

/// Returns the shape of a weight operand, or failure if it is not shaped.
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
  auto shapedType = dyn_cast<ShapedType>(weight.getType());
  if (!shapedType)
    return failure();
  return shapedType.getShape();
}

//===----------------------------------------------------------------------===//
// compute_batch lane helpers
//===----------------------------------------------------------------------===//

/// Returns the laneCount of the closest enclosing spat.compute_batch, or
/// failure if `op` is not nested inside one.
static FailureOr<int64_t> getParentBatchLaneCount(Operation* op) {
  auto batchOp = op->getParentOfType<SpatComputeBatch>();
  if (!batchOp)
    return failure();
  return static_cast<int64_t>(batchOp.getLaneCount());
}

/// Returns true if `value` is one of the output block arguments of
/// `batchOp`'s entry block (one per result, contiguous from output #0).
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
  if (batchOp.getNumResults() == 0)
    return false;

  auto blockArg = dyn_cast<BlockArgument>(value);
  if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
    return false;

  auto firstOutputArg = batchOp.getOutputArgument(0);
  if (!firstOutputArg)
    return false;

  unsigned argNumber = blockArg.getArgNumber();
  unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
  return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
}

/// Returns true if `value` matches a constant integer (including index).
static bool isConstantIndexLike(Value value) {
  APInt constantValue;
  return matchPattern(value, m_ConstantInt(&constantValue));
}

/// Accepts affine expressions built from dims and constants where every
/// multiplication has a constant factor and every floordiv/ceildiv/mod has a
/// constant right-hand side. Symbols are rejected.
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
  switch (expr.getKind()) {
  case AffineExprKind::Constant:
  case AffineExprKind::DimId:
    return true;
  case AffineExprKind::SymbolId:
    return false;
  case AffineExprKind::Add: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS());
  }
  case AffineExprKind::Mul: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()))
        || (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()));
  }
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: {
    auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
    return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS());
  }
  }
  llvm_unreachable("unexpected affine expression kind");
}

/// Returns true if `value` is a "simple" lane-dependent offset. That is one of:
///  - the lane argument itself or a constant,
///  - a single-result, symbol-free affine.apply over supported operands,
///  - a tensor.extract from a rank-1 dense constant at a supported index,
///  - an arith.addi of the lane argument and a constant.
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
  if (value == laneArg || isConstantIndexLike(value))
    return true;

  if (auto affineApply = value.getDefiningOp<affine::AffineApplyOp>()) {
    AffineMap map = affineApply.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumSymbols() != 0)
      return false;
    if (!llvm::all_of(affineApply.getMapOperands(),
            [&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); }))
      return false;
    return isSupportedLaneAffineExpr(map.getResult(0));
  }

  if (auto extractOp = value.getDefiningOp<tensor::ExtractOp>()) {
    auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
    auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
    if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
      return false;
    return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
  }

  auto addOp = value.getDefiningOp<arith::AddIOp>();
  if (!addOp)
    return false;
  return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
      || (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
}

/// Shared slice checks: unit strides, static sizes, and offsets where only
/// dimension 0 may be lane-dependent and every other dimension is constant.
///
/// `getOffsets()` holds only the *dynamic* offsets, so dimensions are recovered
/// by walking `getStaticOffsets()` and consuming one dynamic value per dynamic
/// marker. Static offsets are compile-time constants and always accepted.
template <typename SliceOpT>
static LogicalResult verifyUnitStrideStaticSizesAndLaneOffsets(
    SliceOpT sliceOp, BlockArgument laneArg, StringRef kind) {
  if (!sliceOp.hasUnitStride())
    return sliceOp.emitOpError() << kind << " requires unit strides";

  for (int64_t size : sliceOp.getStaticSizes())
    if (ShapedType::isDynamic(size))
      return sliceOp.emitOpError() << kind << " requires static slice sizes";

  ArrayRef<int64_t> staticOffsets = sliceOp.getStaticOffsets();
  OperandRange dynamicOffsets = sliceOp.getOffsets();
  size_t nextDynamicOffset = 0;
  for (auto [dim, staticOffset] : llvm::enumerate(staticOffsets)) {
    if (!ShapedType::isDynamic(staticOffset))
      continue;
    // The slice op's own verifier reports count mismatches. It may run after
    // this one, so stop instead of indexing out of range.
    if (nextDynamicOffset >= dynamicOffsets.size())
      break;
    Value offset = dynamicOffsets[nextDynamicOffset++];
    bool supported = dim == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
    if (!supported)
      return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
  }
  return success();
}

/// Verifies a tensor.extract_slice inside a compute_batch body.
static LogicalResult verifyStaticUnitStrideExtractSliceOp(
    tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
  auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
  auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";

  return verifyUnitStrideStaticSizesAndLaneOffsets(sliceOp, laneArg, kind);
}

/// Verifies a tensor.parallel_insert_slice inside spat.in_parallel.
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(
    tensor::ParallelInsertSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
  RankedTensorType sourceType = sliceOp.getSourceType();
  RankedTensorType destType = sliceOp.getDestType();
  if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";

  return verifyUnitStrideStaticSizesAndLaneOffsets(sliceOp, laneArg, kind);
}

//===----------------------------------------------------------------------===//
// Channel metadata helpers
//===----------------------------------------------------------------------===//

/// Returns the total byte size of `type`. Fails, emitting a diagnostic on
/// `op`, unless `type` is a statically shaped type whose element type is an
/// integer or float with a whole number of bytes.
static FailureOr<int64_t> getStaticTensorByteSize(Operation* op, Type type, StringRef kind) {
  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape()) {
    op->emitError() << kind << " requires a static shaped tensor";
    return failure();
  }

  // getIntOrFloatBitWidth() asserts on other element types (e.g. index).
  Type elementType = shapedType.getElementType();
  if (!elementType.isIntOrFloat()) {
    op->emitError() << kind << " requires byte-sized elements";
    return failure();
  }

  int64_t elementBits = elementType.getIntOrFloatBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0) {
    op->emitError() << kind << " requires byte-sized elements";
    return failure();
  }
  return shapedType.getNumElements() * elementBits / 8;
}

/// Verifies a non-batched tensor channel op. The three metadata lists must
/// have the same non-zero length, and that length must evenly divide the
/// tensor's byte size.
static LogicalResult verifyTensorChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
  if (channelCount == 0)
    return op->emitError() << kind << " must carry at least one chunk";

  FailureOr<int64_t> totalBytes = getStaticTensorByteSize(op, type, kind);
  if (failed(totalBytes))
    return failure();
  if (*totalBytes % static_cast<int64_t>(channelCount) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
  return success();
}

/// Verifies a batched scalar channel op. It must be nested in a
/// spat.compute_batch, with one channel entry per parent lane.
static LogicalResult verifyBatchChannelSizes(
    Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  FailureOr<int64_t> laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");
  if (channelCount != static_cast<size_t>(*laneCount))
    return op->emitError("channel metadata length must match parent laneCount");
  return success();
}

/// Verifies a batched tensor channel op. It must be nested in a
/// spat.compute_batch, the channel count must be a positive multiple of the
/// parent laneCount, and the per-lane chunk count must divide the tensor's
/// byte size.
static LogicalResult verifyTensorBatchChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  FailureOr<int64_t> laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");
  // The parent's own laneCount check is not guaranteed to have rejected the
  // IR yet; guard the modulo/division below against zero or negative lanes.
  if (*laneCount <= 0)
    return op->emitError("parent spat.compute_batch must have a positive laneCount");
  if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
    return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";

  FailureOr<int64_t> totalBytes = getStaticTensorByteSize(op, type, kind);
  if (failed(totalBytes))
    return failure();

  int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
  if (*totalBytes % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
  return success();
}

//===----------------------------------------------------------------------===//
// Region isolation helpers
//===----------------------------------------------------------------------===//

/// Returns the region that defines `value` (block arguments included), or
/// nullptr if it cannot be determined.
static Region* getParentRegion(Value value) {
  if (auto blockArg = dyn_cast<BlockArgument>(value))
    return blockArg.getOwner()->getParent();
  if (Operation* definingOp = value.getDefiningOp())
    return definingOp->getParentRegion();
  return nullptr;
}

/// Returns true if `value` is defined in `region` or in a region nested in it.
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}

/// Returns true if `value` is produced by a ConstantLike op.
static bool isConstantExternalValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}

/// Emits one diagnostic per operand used inside `region` that is defined
/// outside it and is not produced by a constant op.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  bool hasFailure = false;
  region.walk([&](Operation* op) {
    for (OpOperand& operand : op->getOpOperands()) {
      Value value = operand.get();
      if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
        continue;

      InFlightDiagnostic diagnostic = ownerOp->emitOpError()
                                   << kind << " body may only directly reference external constants";
      diagnostic.attachNote(op->getLoc())
          << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
      hasFailure = true;
    }
  });
  return success(!hasFailure);
}

/// Verifies the terminator and the extract_slice ops of a compute_batch body.
/// A resultless batch must end in an empty spat.yield; a resultful batch must
/// end in spat.in_parallel.
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
  // Block::getTerminator() asserts when the last op is not a terminator, and
  // the generic block checks may not have run yet when this verifier executes.
  Operation* terminator = block.mightHaveTerminator() ? block.getTerminator() : nullptr;

  if (batchOp.getNumResults() == 0) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(terminator);
    if (!yieldOp)
      return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
    if (yieldOp.getNumOperands() != 0)
      return batchOp.emitError("resultless compute_batch body yield must be empty");
  }
  else if (!isa_and_nonnull<SpatInParallelOp>(terminator)) {
    return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
  }

  auto laneArg = batchOp.getLaneArgument();
  if (!laneArg)
    return batchOp.emitError("compute_batch body must have a lane block argument");

  for (Operation& bodyOp : block)
    if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
      if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
        return failure();

  return success();
}

} // namespace

//===----------------------------------------------------------------------===//
// Arithmetic ops
//===----------------------------------------------------------------------===//

/// spat.mvm: dispatches on the weight rank (2 or 4) to the matching layout check.
LogicalResult SpatMVMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
  if (failed(matrixShapeOpt))
    return emitError("weight must be a shaped value");

  ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
  ArrayRef<int64_t> vectorShape = getInput().getType().getShape();
  ArrayRef<int64_t> outputShape = getOutput().getType().getShape();

  if (matrixShape.size() == 2)
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  if (matrixShape.size() == 4)
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  return emitError("matrix rank must be 2 or 4");
}

/// spat.vmm: matrix (N, M) with N > 0 and M > 0, vector (1, N), output (1, M).
LogicalResult SpatVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
  if (failed(matrixShapeOpt))
    return emitError("weight must be a shaped value");

  ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
  ArrayRef<int64_t> vectorShape = getInput().getType().getShape();
  ArrayRef<int64_t> outputShape = getOutput().getType().getShape();

  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  if (vectorShape[1] != N || vectorShape[0] != 1)
    return emitError("vector shape must be (1, N)");

  if (outputShape[1] != M || outputShape[0] != 1)
    return emitError("output shape must be (1, M)");

  return success();
}

/// spat.vadd: at least two operands, all with the result's type.
LogicalResult SpatVAddOp::verify() {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return failure();
  return OpTrait::impl::verifySameOperandsAndResultType(*this);
}

/// spat.vmax: at least two operands, all with the result's type.
LogicalResult SpatVMaxOp::verify() {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return failure();
  return OpTrait::impl::verifySameOperandsAndResultType(*this);
}

//===----------------------------------------------------------------------===//
// Data movement ops
//===----------------------------------------------------------------------===//

/// spat.extract_rows: splits a rank-2 input into one (1, cols) result per row.
/// Row and column counts are only checked against static input dimensions.
LogicalResult SpatExtractRowsOp::verify() {
  auto inputType = dyn_cast<ShapedType>(getInput().getType());
  if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
    return emitError("input must be a rank-2 shaped type");

  int64_t numRows = inputType.getShape()[0];
  int64_t numCols = inputType.getShape()[1];
  Type elementType = inputType.getElementType();

  if (!ShapedType::isDynamic(numRows) && static_cast<int64_t>(getNumResults()) != numRows)
    return emitError("number of outputs must match the number of input rows");

  for (Type output : getResultTypes()) {
    auto outputType = dyn_cast<ShapedType>(output);
    if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
      return emitError("outputs must all be rank-2 shaped types");
    if (outputType.getElementType() != elementType)
      return emitError("output element types must match input element type");

    ArrayRef<int64_t> outputShape = outputType.getShape();
    if (outputShape[0] != 1)
      return emitError("each output must have exactly one row");
    if (!ShapedType::isDynamic(numCols) && outputShape[1] != numCols)
      return emitError("output column count must match input column count");
  }
  return success();
}

/// spat.concat: ranked inputs of the output's rank and element type. Sizes are
/// compared only where both sides are static, and the concatenated dimension
/// is summed only when every input's size along it is static.
LogicalResult SpatConcatOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");

  auto outputType = dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType || !outputType.hasRank())
    return emitError("output must be a ranked shaped type");

  int64_t axis = getAxis();
  int64_t rank = outputType.getRank();
  if (axis < 0 || axis >= rank)
    return emitError("axis must be within the output rank");

  int64_t concatenatedDimSize = 0;
  bool concatenatedDimDynamic = false;
  Type outputElementType = outputType.getElementType();

  for (Value input : getInputs()) {
    auto inputType = dyn_cast<ShapedType>(input.getType());
    if (!inputType || !inputType.hasRank())
      return emitError("inputs must be ranked shaped types");
    if (inputType.getRank() != rank)
      return emitError("all inputs must have the same rank as the output");
    if (inputType.getElementType() != outputElementType)
      return emitError("all inputs must have the same element type as the output");

    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == axis)
        continue;
      int64_t inputDim = inputType.getDimSize(dim);
      int64_t outputDim = outputType.getDimSize(dim);
      if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
        return emitError("non-concatenated dimensions must match the output shape");
    }

    int64_t inputConcatDim = inputType.getDimSize(axis);
    if (ShapedType::isDynamic(inputConcatDim)) {
      concatenatedDimDynamic = true;
      continue;
    }
    concatenatedDimSize += inputConcatDim;
  }

  int64_t outputConcatDim = outputType.getDimSize(axis);
  if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
    return emitError("output concatenated dimension must equal the sum of input sizes");

  return success();
}

//===----------------------------------------------------------------------===//
// Compute ops
//===----------------------------------------------------------------------===//

/// Rejects results of a spat.compute / spat.compute_batch that are used by an
/// op nested (at any depth) inside another spat.compute or spat.compute_batch.
LogicalResult verifyComputeResultsUses(Operation* op) {
  if (!isa<SpatCompute, SpatComputeBatch>(op))
    return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");

  auto isInsideCompute = [](Operation* user) {
    return user->getParentOfType<SpatCompute>() || user->getParentOfType<SpatComputeBatch>();
  };
  for (Value result : op->getResults())
    if (llvm::any_of(result.getUsers(), isInsideCompute))
      return op->emitError("ComputeResult used directly inside another Compute");

  return success();
}

/// spat.compute: checks the weight and input block arguments against the
/// operands, the yield against the results, that every input argument is used,
/// that the body references only external constants, and the result uses.
LogicalResult SpatCompute::verify() {
  Block& block = getBody().front();

  unsigned expectedArgCount = getWeights().size() + getInputs().size();
  if (block.getNumArguments() != expectedArgCount)
    return emitError("compute body must have weight and input block arguments");

  for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
    auto blockArg = getWeightArgument(weightIndex);
    if (!blockArg || blockArg->getType() != weight.getType())
      return emitError("compute weight block argument types must match weight operand types exactly");
  }

  for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
    auto blockArg = getInputArgument(inputIndex);
    if (!blockArg || blockArg->getType() != input.getType())
      return emitError("compute input block argument types must match input operand types exactly");
  }

  if (block.mightHaveTerminator()) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
    if (!yieldOp)
      return emitError("ComputeOp must have a single yield operation");

    auto resultTypes = getResultTypes();
    auto yieldTypes = yieldOp->getOperandTypes();
    if (resultTypes.size() != yieldTypes.size())
      return emitError("ComputeOp must have same number of results as yieldOp operands");

    // MLIR types are uniqued, so type equality already implies the same shape
    // and the same tensor encoding; no separate checks are needed.
    for (auto [resultType, yieldType] : llvm::zip(resultTypes, yieldTypes))
      if (resultType != yieldType)
        return emitError("ComputeOp output must be of the same type as yieldOp operand");
  }

  for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
    if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
      return emitError("ComputeOp block argument is not used");

  if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
    return failure();

  if (failed(verifyComputeResultsUses(this->getOperation())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// Channel ops
//===----------------------------------------------------------------------===//

LogicalResult SpatChannelSendTensorOp::verify() {
  return verifyTensorChannelSizes(getOperation(),
      getInput().getType(),
      getChannelIds().size(),
      getSourceCoreIds().size(),
      getTargetCoreIds().size(),
      "channel_send_tensor");
}

LogicalResult SpatChannelReceiveTensorOp::verify() {
  return verifyTensorChannelSizes(getOperation(),
      getOutput().getType(),
      getChannelIds().size(),
      getSourceCoreIds().size(),
      getTargetCoreIds().size(),
      "channel_receive_tensor");
}

LogicalResult SpatChannelSendBatchOp::verify() {
  return verifyBatchChannelSizes(
      getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}

LogicalResult SpatChannelSendTensorBatchOp::verify() {
  return verifyTensorBatchChannelSizes(getOperation(),
      getInput().getType(),
      getChannelIds().size(),
      getSourceCoreIds().size(),
      getTargetCoreIds().size(),
      "channel_send_tensor_batch");
}

LogicalResult SpatChannelReceiveBatchOp::verify() {
  return verifyBatchChannelSizes(
      getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}

LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
  return verifyTensorBatchChannelSizes(getOperation(),
      getOutput().getType(),
      getChannelIds().size(),
      getSourceCoreIds().size(),
      getTargetCoreIds().size(),
      "channel_receive_tensor_batch");
}

//===----------------------------------------------------------------------===//
// compute_batch / in_parallel
//===----------------------------------------------------------------------===//

/// spat.compute_batch: positive laneCount; optional coreIds attribute of
/// laneCount unique non-negative i32 values; block arguments laid out as
/// (lane: index, weights..., inputs..., outputs...) with exactly matching
/// types; then the result-use, external-value, and body checks.
LogicalResult SpatComputeBatch::verify() {
  int32_t count = getLaneCount();
  if (count <= 0)
    return emitError("laneCount must be positive");

  if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
    auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
    if (!coreIdsAttr)
      return emitError("compute_batch coreIds attribute must be a dense i32 array");
    if (coreIdsAttr.size() != static_cast<int64_t>(count))
      return emitError("compute_batch coreIds array length must match laneCount");
    if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
      return emitError("compute_batch coreIds values must be non-negative");

    DenseSet<int32_t> seenCoreIds;
    for (int32_t coreId : coreIdsAttr.asArrayRef())
      if (!seenCoreIds.insert(coreId).second)
        return emitError("compute_batch coreIds values must be unique");
  }

  Block& block = getBody().front();
  if (block.getNumArguments() == 0)
    return emitError("compute_batch body must have exactly one lane block argument");

  unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
  if (block.getNumArguments() != expectedArgCount)
    return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");

  auto laneArg = getLaneArgument();
  if (!laneArg || !laneArg->getType().isIndex())
    return emitError("compute_batch first block argument must have index type");

  for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
    auto blockArg = getWeightArgument(weightIndex);
    if (!blockArg || blockArg->getType() != weight.getType())
      return emitError("compute_batch weight block argument types must match weight operand types exactly");
  }

  for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
    auto blockArg = getInputArgument(inputIndex);
    if (!blockArg || blockArg->getType() != input.getType())
      return emitError("compute_batch input block argument types must match input operand types exactly");
  }

  for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
    auto blockArg = getOutputArgument(resultIndex);
    if (!blockArg || blockArg->getType() != resultType)
      return emitError("compute_batch output block argument types must match result types exactly");
  }

  if (failed(verifyComputeResultsUses(this->getOperation())))
    return failure();

  if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
    return failure();

  return verifyBatchBody(*this, block);
}

/// spat.in_parallel: must sit in a resultful spat.compute_batch and contain
/// only tensor.parallel_insert_slice ops that write into the batch's output
/// block arguments, using lane-dependent offsets only in dimension 0.
LogicalResult SpatInParallelOp::verify() {
  auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
  if (!batchOp)
    return emitOpError("expected spat.compute_batch parent");
  if (batchOp.getNumResults() == 0)
    return emitOpError("requires a resultful spat.compute_batch parent");

  auto laneArg = batchOp.getLaneArgument();
  if (!laneArg)
    return emitOpError("expected compute_batch lane block argument");

  for (Operation& op : getRegion().front().getOperations()) {
    auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
    if (!insertSliceOp)
      return emitOpError("expected only tensor.parallel_insert_slice ops");

    if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
      return failure();

    MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
    for (OpOperand& destination : destinations)
      if (!isBatchOutputArgument(batchOp, destination.get()))
        return op.emitOpError("may only insert into a compute_batch output block argument");
  }
  return success();
}

} // namespace spatial
} // namespace onnx_mlir