#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/Support/LogicalResult.h"

#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

// Custom verifiers for PIM dialect operations: PimCoreOp, PimCoreBatchOp,
// the tensor send/receive ops (plain and batched), PimVMMOp and PimConcatOp.
//
// NOTE(review): this copy of the file appears to have lost every template
// argument list (e.g. `isa(op)`, `dyn_cast(value)`, `hasTrait()`,
// `getParentOfType()`, `ArrayRef coreIds`, `static_cast(...)`,
// `FailureOr>`), and `&region` is rendered as `®ion` (looks like an HTML
// `&reg` entity decode). As shown it will not compile; restore the type
// arguments and the `&` from version control.

using namespace mlir;

namespace onnx_mlir {
namespace pim {

namespace {

// Returns true when operand `operandIndex` of `op` is exempt from the
// "only external constants" rule enforced by verifyOnlyConstantExternalValues.
// Each branch matches one op kind and exempts exactly one operand position.
// NOTE(review): the op types tested by the three `isa` calls are missing from
// this copy; going by the function name these are presumably ops whose
// operand at that index is an explicit host-side value — confirm against the
// op definitions.
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
  if (isa(op))
    return operandIndex == 3;
  if (isa(op))
    return operandIndex == 1;
  if (isa(op))
    return operandIndex == 2;
  return false;
}

// Returns the region that owns `value`: for a block argument, the parent
// region of its block; otherwise the parent region of its defining op.
// Returns nullptr when no owning region can be determined.
static Region* getParentRegion(Value value) {
  if (auto blockArgument = dyn_cast(value))
    return blockArgument.getParentRegion();
  Operation* definingOp = value.getDefiningOp();
  return definingOp ? definingOp->getParentRegion() : nullptr;
}

// True when `value` is defined in `region` itself or in a region nested
// beneath it (the isAncestor() call covers the nested case).
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
}

// True when `value` counts as a constant for the external-reference check,
// i.e. it may be used inside a core body even though it is defined outside.
// Two producers qualify:
//   * an op carrying a specific trait (presumably OpTrait::ConstantLike — the
//     template argument is missing in this copy);
//   * a get-global op (presumably memref::GetGlobalOp) whose global, looked up
//     through the enclosing module, is marked constant.
// Block arguments (no defining op) never qualify.
static bool isConstantExternalValue(Value value) {
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return false;
  if (definingOp->hasTrait())
    return true;
  auto getGlobalOp = dyn_cast(definingOp);
  if (!getGlobalOp)
    return false;
  auto moduleOp = definingOp->getParentOfType();
  // NOTE(review): moduleOp may be null if the op is not nested in a module;
  // behavior of lookupGlobalForGetGlobal in that case is defined elsewhere.
  auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
  return globalOp && globalOp.getConstant();
}

// Checks that every operand of every op nested in `region` is either defined
// inside `region`, a constant external value, or an exempt explicit host
// operand. `kind` names the owner op in the diagnostic text.
// It does not stop at the first violation: one error is emitted on `ownerOp`
// per offending operand, each with a note at the op that uses it, and the
// result is failure if any violation was found.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  bool hasFailure = false;
  region.walk([&](Operation* op) {
    for (OpOperand& operand : op->getOpOperands()) {
      Value value = operand.get();
      if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
          || isExplicitHostOperand(op, operand.getOperandNumber()))
        continue;
      // Report every offending use rather than only the first one.
      InFlightDiagnostic diagnostic =
          ownerOp->emitOpError() << kind << " body may only directly reference external constants";
      diagnostic.attachNote(op->getLoc())
          << "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
      hasFailure = true;
    }
  });
  return success(!hasFailure);
}

// True when `lhs` and `rhs` belong to the same shaped-container family; each
// parenthesised pair tests both types against one family.
// NOTE(review): the family types are missing from this copy; judging by the
// "tensor or memref" diagnostics below, presumably one tensor family and one
// memref family — confirm.
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
  return (isa(lhs) && isa(rhs)) || (isa(lhs) && isa(rhs));
}

// Emits `message` as an error on `op` unless `lhs` and `rhs` are both shaped
// types of the same container kind with identical element types and
// identical shapes (dimensions, dynamic ones included, compared literally).
// NOTE(review): in upstream MLIR, ShapedType::getShape() asserts on unranked
// types; if unranked types can reach this helper, check hasRank() before
// comparing shapes to avoid a verifier crash — confirm.
static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) {
  auto lhsShaped = dyn_cast(lhs);
  auto rhsShaped = dyn_cast(rhs);
  if (!lhsShaped || !rhsShaped || !haveSameShapedContainerKind(lhs, rhs))
    return op->emitError(message);
  if (lhsShaped.getElementType() != rhsShaped.getElementType() || lhsShaped.getShape() != rhsShaped.getShape())
    return op->emitError(message);
  return success();
}

// Shared verifier for the non-batched tensor send/receive ops. `type` is the
// transferred value's type, `coreIds` the peer core ids, and `kind` prefixes
// every diagnostic. The transfer is accepted when:
//   * at least one core id is given,
//   * `type` is a shaped type with a fully static shape,
//   * its element width is a positive multiple of 8 bits, and
//   * the total byte size divides evenly by the number of core ids.
// NOTE(review): in upstream MLIR, getElementTypeBitWidth() asserts for element
// types that are not integer/float (e.g. index, complex) instead of returning
// a value <= 0, so the `elementBits <= 0` guard may not catch them; consider
// checking the element type first — confirm.
static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRef coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";
  auto shapedType = dyn_cast(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";
  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";
  int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
  if (totalBytes % static_cast(coreIds.size()) != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of core ids";
  return success();
}

// Batched counterpart of verifyTensorCommunication for ops nested inside a
// core-batch op. The core id list must be non-empty and its length divisible
// by the parent's laneCount; chunkCount = coreIds.size() / laneCount is the
// number of chunks per lane, and the static byte size of `type` must divide
// evenly by chunkCount. Element-type and static-shape requirements match
// verifyTensorCommunication (including the getElementTypeBitWidth caveat).
static LogicalResult verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef coreIds, StringRef kind) {
  if (coreIds.empty())
    return op->emitError() << kind << " must carry at least one chunk";
  // Presumably getParentOfType<PimCoreBatchOp>() — the diagnostic names
  // pim.core_batch and getLaneCount() is read from the result.
  auto coreBatchOp = op->getParentOfType();
  if (!coreBatchOp)
    return op->emitError() << kind << " must be nested inside pim.core_batch";
  int32_t laneCount = coreBatchOp.getLaneCount();
  // Guards the modulo and division by laneCount below.
  if (laneCount <= 0)
    return op->emitError() << kind << " requires a positive parent laneCount";
  if (coreIds.size() % static_cast(laneCount) != 0)
    return op->emitError() << kind << " core id count must be divisible by the parent laneCount";
  auto shapedType = dyn_cast(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor or memref";
  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";
  // Non-zero: coreIds is non-empty and its size is a multiple of laneCount.
  int64_t chunkCount = static_cast(coreIds.size()) / laneCount;
  int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
  if (totalBytes % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
  return success();
}

// Returns the shape of `weight` if its type is shaped, failure otherwise.
// The returned ArrayRef is a non-owning view of the type's shape.
static FailureOr> getWeightShapeForVMM(Value weight) {
  auto shapedType = dyn_cast(weight.getType());
  if (!shapedType)
    return failure();
  return shapedType.getShape();
}

} // namespace

// PimCoreOp: the entry block must have exactly one argument per weight
// operand, each typed exactly like its weight, and the body may only
// reference external values accepted by verifyOnlyConstantExternalValues.
// NOTE(review): front() is called unconditionally, so the body is assumed to
// contain a block — presumably guaranteed by the op definition; confirm.
LogicalResult PimCoreOp::verify() {
  Block& block = getBody().front();
  if (block.getNumArguments() != getWeights().size())
    return emitError("core body must have one block argument per weight");
  for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
    if (getWeightArgument(weightIndex).getType() != weight.getType())
      return emitError("core weight block argument types must match weight operand types exactly");
  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}

// PimCoreBatchOp: laneCount must be positive; the entry block must have
// 1 + #weights + #inputs arguments, the lane argument (first) of index type
// plus one argument per weight and per input, each typed exactly like its
// operand. The body is then subject to the same external-reference rule as
// PimCoreOp.
LogicalResult PimCoreBatchOp::verify() {
  if (getLaneCount() <= 0)
    return emitError("laneCount must be positive");
  Block& block = getBody().front();
  unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
  if (block.getNumArguments() != expectedArgCount)
    return emitError("core_batch body must have lane, weight, and input block arguments");
  if (!getLaneArgument().getType().isIndex())
    return emitError("core_batch first block argument must have index type");
  for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
    if (getWeightArgument(weightIndex).getType() != weight.getType())
      return emitError("core_batch weight block argument types must match weight operand types exactly");
  for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
    if (getInputArgument(inputIndex).getType() != input.getType())
      return emitError("core_batch input block argument types must match input operand types exactly");
  return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}

// Validates the sent value's type against the target core ids.
LogicalResult PimSendTensorOp::verify() {
  return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}

// Batched send: additionally requires a core-batch parent (see helper).
LogicalResult PimSendTensorBatchOp::verify() {
  return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch");
}

// The output buffer must match the result type; the result type is then
// validated against the source core ids.
LogicalResult PimReceiveTensorOp::verify() {
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();
  return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor");
}

// Batched receive: same buffer/result check, then the batched transfer rules.
LogicalResult PimReceiveTensorBatchOp::verify() {
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();
  return verifyTensorBatchCommunication(
      getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
}

// PimVMMOp (presumably vector-matrix multiply) requires:
//   * output buffer and result types to match,
//   * a rank-2 shaped weight of shape (N, M) with 0 < N, M <= crossbarSize,
//   * a rank-2 input of shape (1, crossbarSize), and
//   * a rank-2 output of shape (1, crossbarSize).
// Input/output widths are the full crossbarSize rather than N or M.
// crossbarSize is not defined in this file (presumably a PIM compiler option).
// Dynamic dimensions are a negative sentinel in upstream MLIR, so they fail
// these checks as well — confirm for the MLIR version in use.
LogicalResult PimVMMOp::verify() {
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();

  auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
  if (failed(matrixShapeOpt))
    return emitError("weight must be a shaped value");
  ArrayRef matrixShape = *matrixShapeOpt;

  auto vectorType = dyn_cast(getInput().getType());
  auto outputType = dyn_cast(getOutput().getType());
  if (!vectorType || !outputType)
    return emitError("input and output must be shaped types");

  ArrayRef vectorShape = vectorType.getShape();
  ArrayRef outputShape = outputType.getShape();
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
  if (N > static_cast(crossbarSize) || M > static_cast(crossbarSize))
    return emitError("matrix dimensions must fit in one crossbar");

  int64_t vector1 = vectorShape[0];
  int64_t vectorWidth = vectorShape[1];
  if (vector1 != 1 || vectorWidth != static_cast(crossbarSize))
    return emitError("vector shape must be (1, crossbar-size)");

  int64_t output1 = outputShape[0];
  int64_t outputWidth = outputShape[1];
  if (output1 != 1 || outputWidth != static_cast(crossbarSize))
    return emitError("output shape must be (1, crossbar-size)");

  return success();
}

// PimConcatOp requires at least one input, a matching output buffer, and a
// ranked output with 0 <= axis < rank. Every input must be ranked, use the
// same container kind as the output, and share its rank and element type.
// Non-axis dimensions must equal the output's where both are static. Along the
// axis, input sizes must sum to the output size; that check is skipped when
// any input extent or the output extent on the axis is dynamic.
LogicalResult PimConcatOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");
  if (failed(verifyCompatibleShapedTypes(
          getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
    return failure();

  auto outputType = dyn_cast(getOutput().getType());
  if (!outputType || !outputType.hasRank())
    return emitError("output must be a ranked shaped type");

  int64_t axis = getAxis();
  int64_t rank = outputType.getRank();
  if (axis < 0 || axis >= rank)
    return emitError("axis must be within the output rank");

  int64_t concatenatedDimSize = 0;
  bool concatenatedDimDynamic = false;
  Type outputElementType = outputType.getElementType();
  for (Value input : getInputs()) {
    auto inputType = dyn_cast(input.getType());
    if (!inputType || !inputType.hasRank())
      return emitError("inputs must be ranked shaped types");
    if (!haveSameShapedContainerKind(input.getType(), getOutput().getType()))
      return emitError("inputs and output must use the same shaped container kind");
    if (inputType.getRank() != rank)
      return emitError("all inputs must have the same rank as the output");
    if (inputType.getElementType() != outputElementType)
      return emitError("all inputs must have the same element type as the output");

    // Only dimensions other than the concatenation axis must agree, and only
    // when both sides are static.
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == axis)
        continue;
      int64_t inputDim = inputType.getDimSize(dim);
      int64_t outputDim = outputType.getDimSize(dim);
      if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
        return emitError("non-concatenated dimensions must match the output shape");
    }

    // One dynamic extent on the axis makes the total unknown; remember that
    // and stop accumulating for this input.
    int64_t inputConcatDim = inputType.getDimSize(axis);
    if (ShapedType::isDynamic(inputConcatDim)) {
      concatenatedDimDynamic = true;
      continue;
    }
    concatenatedDimSize += inputConcatDim;
  }

  int64_t outputConcatDim = outputType.getDimSize(axis);
  if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
    return emitError("output concatenated dimension must equal the sum of input sizes");

  return success();
}

} // namespace pim
} // namespace onnx_mlir