#include "mlir/IR/Block.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/LogicalResult.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; namespace onnx_mlir { namespace spatial { namespace { inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter, ArrayRef& matrixShape, ArrayRef& vectorShape, ArrayRef& outputShape) { if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) return emitter->emitError("matrix, vector and output must have rank 2"); int64_t N = matrixShape[0]; int64_t M = matrixShape[1]; if (N <= 0 || M <= 0) return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0"); int64_t vectorM = vectorShape[0]; int64_t vector1 = vectorShape[1]; if (vectorM != M || vector1 != 1) return emitter->emitError("vector shape must be (M, 1)"); int64_t outputN = outputShape[0]; int64_t output1 = outputShape[1]; if (outputN != N || output1 != 1) return emitter->emitError("output shape must be (N, 1)"); return success(); } inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter, ArrayRef& matrixShape, ArrayRef& vectorShape, ArrayRef& outputShape) { if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4) return emitter->emitError("matrix, vector and output must have rank 4"); int64_t N = matrixShape[0]; int64_t M = matrixShape[1]; int64_t matrix1First = matrixShape[2]; int64_t matrix1Second = matrixShape[3]; if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1) return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0"); int64_t vector1First = vectorShape[0]; int64_t vectorM = vectorShape[1]; int64_t vector1Second = vectorShape[2]; int64_t vector1Third = vectorShape[3]; if (vector1First != 1 || vectorM != M || vector1Second != 1 || vector1Third != 1) { if (vector1First == 1 && vector1Second == 1 && vector1Third == 1 && ignoreConcatError == true) { // This is ok, it was caused by the simplification of the concat error. } else { return emitter->emitError("vector shape must be (1, M, 1, 1)"); } } int64_t output1First = outputShape[0]; int64_t outputN = outputShape[1]; int64_t output1Second = outputShape[2]; int64_t output1Third = outputShape[3]; if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1) return emitter->emitError("output shape must be (1, N, 1, 1)"); return success(); } static FailureOr> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { if (auto computeOp = weightedOp->getParentOfType()) return cast(computeOp.getWeights()[weightIndex].getType()).getShape(); if (auto coreOp = weightedOp->getParentOfType()) return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); if (auto batchOp = weightedOp->getParentOfType()) { if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) return failure(); return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); } return failure(); } static FailureOr getParentBatchLaneCount(Operation* op) { auto batchOp = op->getParentOfType(); if (!batchOp) return failure(); return batchOp.getLaneCount(); } static LogicalResult verifyManyChannelSizes(Operation* op, ArrayRef channelIds, ArrayRef sourceCoreIds, ArrayRef targetCoreIds, size_t valueCount) { if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); if (channelIds.size() != valueCount) return op->emitError("channel metadata length must match the number of values"); return success(); } static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) { if (types.empty()) return op->emitError() << kind << " must carry at least one value"; Type firstType = types.front(); for (Type type : types.drop_front()) if (type != firstType) return op->emitError() << kind << " values must all have the same type"; return success(); } static LogicalResult verifyBatchChannelSizes(Operation* op, ArrayRef channelIds, ArrayRef sourceCoreIds, ArrayRef targetCoreIds) { if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); auto laneCount = getParentBatchLaneCount(op); if (failed(laneCount)) return op->emitError("must be nested inside spat.compute_batch"); if (channelIds.size() != static_cast(*laneCount)) return op->emitError("channel metadata length must match parent laneCount"); return success(); } static LogicalResult verifyManyBatchChannelSizes(Operation* op, ArrayRef channelIds, ArrayRef sourceCoreIds, ArrayRef targetCoreIds, size_t valueCount) { if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); auto laneCount = getParentBatchLaneCount(op); if (failed(laneCount)) return op->emitError("must be nested inside spat.compute_batch"); if (channelIds.size() != valueCount * static_cast(*laneCount)) return op->emitError("channel metadata length must match the number of values times parent laneCount"); return success(); } static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) return op->emitError("body must terminate with spat.yield"); if (outputTypes.empty()) { if (yieldOp.getNumOperands() != 0) return op->emitError("body yield must be empty when compute_batch has no results"); } else { if (yieldOp.getNumOperands() != 1) return op->emitError("body yield must produce exactly one value"); if (yieldOp.getOperand(0).getType() != outputTypes[0]) return op->emitError("body yield type must match output type"); } for (auto& bodyOp : block) { if (auto wvmm = dyn_cast(&bodyOp)) if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); if (auto wmvm = dyn_cast(&bodyOp)) if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); } return success(); } } // namespace LogicalResult SpatMVMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) return emitError("SpatMVMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); if (matrixShape.size() == 2) return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape); if (matrixShape.size() == 4) return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape); return emitError("matrix rank must be 2 or 4"); } LogicalResult SpatVMMOp::verify() { auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex()); if (failed(matrixShapeOpt)) return emitError("SpatVMMOp was not within a SpatCompute or Core op"); auto matrixShape = *matrixShapeOpt; auto vectorShape = getInput().getType().getShape(); auto outputShape = getOutput().getType().getShape(); if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2) return emitError("matrix, vector and output must have rank 2"); int64_t N = matrixShape[0]; int64_t M = matrixShape[1]; if (N <= 0 || M <= 0) return emitError("matrix shape must be (N, M) with N > 0 and M > 0"); int64_t vector1 = vectorShape[0]; int64_t vectorN = vectorShape[1]; if (vectorN != N || vector1 != 1) return emitError("vector shape must be (1, N)"); int64_t output1 = outputShape[0]; int64_t outputM = outputShape[1]; if (outputM != M || output1 != 1) return emitError("output shape must be (1, M)"); return success(); } LogicalResult SpatVAddOp::verify() { if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) return failure(); return OpTrait::impl::verifySameOperandsAndResultType(*this); } LogicalResult SpatVMaxOp::verify() { if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) return failure(); return OpTrait::impl::verifySameOperandsAndResultType(*this); } LogicalResult SpatExtractRowsOp::verify() { auto inputType = dyn_cast(getInput().getType()); if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) return emitError("input must be a rank-2 shaped type"); int64_t numRows = inputType.getShape()[0]; int64_t numCols = inputType.getShape()[1]; Type elementType = inputType.getElementType(); if (numRows >= 0 && static_cast(getNumResults()) != numRows) return emitError("number of outputs must match the number of input rows"); for (Type output : getResultTypes()) { auto outputType = dyn_cast(output); if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) return emitError("outputs must all be rank-2 shaped types"); if (outputType.getElementType() != elementType) return emitError("output element types must match input element type"); auto outputShape = outputType.getShape(); if (outputShape[0] != 1) return emitError("each output must have exactly one row"); if (numCols >= 0 && outputShape[1] != numCols) return emitError("output column count must match input column count"); } return success(); } LogicalResult SpatConcatOp::verify() { if (getInputs().empty()) return emitError("requires at least one input"); auto outputType = dyn_cast(getOutput().getType()); if (!outputType || !outputType.hasRank()) return emitError("output must be a ranked shaped type"); int64_t axis = getAxis(); int64_t rank = outputType.getRank(); if (axis < 0 || axis >= rank) return emitError("axis must be within the output rank"); int64_t concatenatedDimSize = 0; bool concatenatedDimDynamic = false; Type outputElementType = outputType.getElementType(); for (Value input : getInputs()) { auto inputType = dyn_cast(input.getType()); if (!inputType || !inputType.hasRank()) return emitError("inputs must be ranked shaped types"); if (inputType.getRank() != rank) return emitError("all inputs must have the same rank as the output"); if (inputType.getElementType() != outputElementType) return emitError("all inputs must have the same element type as the output"); for (int64_t dim = 0; dim < rank; ++dim) { if (dim == axis) continue; int64_t inputDim = inputType.getDimSize(dim); int64_t outputDim = outputType.getDimSize(dim); if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim) return emitError("non-concatenated dimensions must match the output shape"); } int64_t inputConcatDim = inputType.getDimSize(axis); if (ShapedType::isDynamic(inputConcatDim)) { concatenatedDimDynamic = true; continue; } concatenatedDimSize += inputConcatDim; } int64_t outputConcatDim = outputType.getDimSize(axis); if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim) return emitError("output concatenated dimension must equal the sum of input sizes"); return success(); } LogicalResult SpatMapOp::verify() { if (getInputs().empty()) return emitError("requires at least one input"); if (getOutputs().size() != getInputs().size()) return emitError("number of outputs must match number of inputs"); Type inputType = getInputs().front().getType(); for (Value input : getInputs().drop_front()) if (input.getType() != inputType) return emitError("all inputs must have the same type"); Type outputType = getOutputs().front().getType(); for (Value output : getOutputs().drop_front()) if (output.getType() != outputType) return emitError("all outputs must have the same type"); Block& block = getBody().front(); if (block.getNumArguments() != 1) return emitError("body must have exactly one block argument"); if (block.getArgument(0).getType() != inputType) return emitError("body block argument type must match input type"); auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) return emitError("body must terminate with spat.yield"); if (yieldOp.getNumOperands() != 1) return emitError("body yield must produce exactly one value"); if (yieldOp.getOperand(0).getType() != outputType) return emitError("body yield type must match output type"); return success(); } LogicalResult SpatCompute::verify() { auto& block = getBody().front(); if (block.mightHaveTerminator()) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) return emitError("ComputeOp must have a single yield operation"); auto resultTypes = getResultTypes(); auto yieldTypes = yieldOp->getOperandTypes(); if (resultTypes.size() != yieldTypes.size()) return emitError("ComputeOp must have same number of results as yieldOp operands"); for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) { auto resultType = std::get<0>(it); auto yieldType = std::get<1>(it); if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) return emitError("ComputeOp output must be of the same type as yieldOp operand"); if (auto resultRankedType = dyn_cast(resultType)) { if (auto yieldRankedType = dyn_cast(yieldType)) { if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) return emitError("ComputeOp output must have the same encoding as yieldOp operand"); } else { return emitError("ComputeOp output has an encoding while yieldOp operand does not have one"); } } else if (dyn_cast(yieldType)) { return emitError("ComputeOp output must not have an encoding if yieldOp operand has one"); } } } for (auto arg : block.getArguments()) if (arg.use_empty()) return emitError("ComputeOp block argument is not used"); return success(); } LogicalResult SpatChannelSendManyOp::verify() { if (failed(verifyManyChannelSizes( getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) return failure(); return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many"); } LogicalResult SpatChannelReceiveManyOp::verify() { if (failed(verifyManyChannelSizes( getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) return failure(); return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many"); } LogicalResult SpatChannelSendBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } LogicalResult SpatChannelSendManyBatchOp::verify() { if (failed(verifyManyBatchChannelSizes( getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) return failure(); return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch"); } LogicalResult SpatChannelReceiveBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } LogicalResult SpatChannelReceiveManyBatchOp::verify() { if (failed(verifyManyBatchChannelSizes( getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) return failure(); return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch"); } LogicalResult SpatComputeBatch::verify() { int32_t count = getLaneCount(); if (count <= 0) return emitError("laneCount must be positive"); auto laneCountSz = static_cast(count); if (getWeights().size() % laneCountSz != 0) return emitError("number of weights must be a multiple of laneCount"); if (!getInputs().empty() && getInputs().size() != laneCountSz) return emitError("number of inputs must be either 0 or laneCount"); if (!getOutputs().empty() && getOutputs().size() != laneCountSz) return emitError("number of outputs must be either 0 or laneCount"); size_t weightsPerLane = getWeights().size() / laneCountSz; for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) { Type weightType = getWeights()[weightIndex].getType(); for (size_t lane = 1; lane < laneCountSz; ++lane) if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType) return emitError("corresponding weights across lanes must have the same type"); } if (!getInputs().empty()) { Type inputType = getInputs()[0].getType(); for (Value in : getInputs().drop_front()) if (in.getType() != inputType) return emitError("all inputs must have the same type"); } if (!getOutputs().empty()) { Type outputType = getOutputs()[0].getType(); for (Value out : getOutputs().drop_front()) if (out.getType() != outputType) return emitError("all outputs must have the same type"); } if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); if (!coreIdsAttr) return emitError("compute_batch coreIds attribute must be a dense i32 array"); if (coreIdsAttr.size() != laneCountSz) return emitError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) return emitError("compute_batch coreIds values must be positive"); llvm::SmallDenseSet seenCoreIds; for (int32_t coreId : coreIdsAttr.asArrayRef()) if (!seenCoreIds.insert(coreId).second) return emitError("compute_batch coreIds values must be distinct"); } Block& block = getBody().front(); if (getInputs().empty()) { if (block.getNumArguments() != 0) return emitError("compute_batch body must have no block arguments when there are no inputs"); } else { if (block.getNumArguments() != 1) return emitError("compute_batch body must have exactly one block argument"); if (block.getArgument(0).getType() != getInputs()[0].getType()) return emitError("body block argument type must match input type"); } return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); } } // namespace spatial } // namespace onnx_mlir