#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"

// NOTE(review): the original standard-library include was garbled in the
// source dump; these cover the fixed-width / size types used below.
#include <cstddef>
#include <cstdint>

#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

/// Registers the TableGen-generated types and operations of the Spatial
/// dialect.
void SpatialDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
      >();
}

namespace {

/// Returns the `weights` operands of the SpatWeightedCompute or Pim core op
/// that directly encloses `op`, or failure when the parent is neither (or
/// when `op` has no parent at all).
// NOTE(review): the Pim core op class name (pim::PimCoreOp) was reconstructed
// from a garbled source dump; confirm against PimOps.hpp.
FailureOr<ValueRange> getEnclosingWeights(Operation* op) {
  // getParentOp() is null for a detached op; dyn_cast_or_null tolerates that.
  Operation* parent = op->getParentOp();
  if (auto computeOp = dyn_cast_or_null<SpatWeightedCompute>(parent))
    return ValueRange(computeOp.getWeights());
  if (auto coreOp = dyn_cast_or_null<pim::PimCoreOp>(parent))
    return ValueRange(coreOp.getWeights());
  return failure();
}

/// Resolves the shape of weight `weightIndex` of the op enclosing `op`.
/// On failure an error is emitted on `op` (using `opName` as the message
/// prefix for the "not nested" case) and failure is returned.
FailureOr<ArrayRef<int64_t>> resolveWeightShape(Operation* op, size_t weightIndex, StringRef opName) {
  FailureOr<ValueRange> weights = getEnclosingWeights(op);
  if (failed(weights)) {
    op->emitError() << opName << " was not within a SpatWeightedCompute or Core op";
    return failure();
  }
  // Guard the index: indexing past the weights range is undefined behavior.
  if (weightIndex >= weights->size()) {
    op->emitError() << "weight index " << weightIndex << " is out of range (enclosing op has " << weights->size()
                    << " weights)";
    return failure();
  }
  return cast<ShapedType>((*weights)[weightIndex].getType()).getShape();
}

/// Verifies the rank-2 matrix-vector layout:
///   matrix (N, M), vector (M, 1), output (N, 1), with N > 0 and M > 0.
/// Diagnostics are emitted on `op`.
LogicalResult mvmOpVerifySize2(Operation* op,
                               ArrayRef<int64_t> matrixShape,
                               ArrayRef<int64_t> vectorShape,
                               ArrayRef<int64_t> outputShape) {
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return op->emitError("matrix, vector and output must have rank 2");

  const int64_t n = matrixShape[0];
  const int64_t m = matrixShape[1];
  // Dynamic dimensions are negative sentinels, so they are rejected here too.
  if (n <= 0 || m <= 0)
    return op->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  if (vectorShape[0] != m || vectorShape[1] != 1)
    return op->emitError("vector shape must be (M, 1)");

  if (outputShape[0] != n || outputShape[1] != 1)
    return op->emitError("output shape must be (N, 1)");

  return success();
}

/// Verifies the rank-4 matrix-vector layout:
///   matrix (N, M, 1, 1), vector (1, M, 1, 1), output (1, N, 1, 1).
/// When the `ignoreConcatError` compiler option is set, a vector whose channel
/// dimension differs from M is tolerated (its other dimensions must still be 1).
LogicalResult mvmOpVerifySize4(Operation* op,
                               ArrayRef<int64_t> matrixShape,
                               ArrayRef<int64_t> vectorShape,
                               ArrayRef<int64_t> outputShape) {
  if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
    return op->emitError("matrix, vector and output must have rank 4");

  const int64_t n = matrixShape[0];
  const int64_t m = matrixShape[1];
  if (n <= 0 || m <= 0 || matrixShape[2] != 1 || matrixShape[3] != 1)
    return op->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  const bool vectorUnitDimsOk = vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1;
  // A channel mismatch can be produced by the concat simplification; it is
  // accepted only when the user opted in via ignoreConcatError.
  const bool vectorChannelsOk = vectorShape[1] == m || static_cast<bool>(ignoreConcatError);
  if (!vectorUnitDimsOk || !vectorChannelsOk)
    return op->emitError("vector shape must be (1, M, 1, 1)");

  if (outputShape[0] != 1 || outputShape[1] != n || outputShape[2] != 1 || outputShape[3] != 1)
    return op->emitError("output shape must be (1, N, 1, 1)");

  return success();
}

/// Shared verifier for variadic elementwise ops: requires at least two
/// operands and identical operand and result types.
LogicalResult verifyVariadicSameTypeOp(Operation* op) {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 2)))
    return failure();
  return OpTrait::impl::verifySameOperandsAndResultType(op);
}

} // namespace

/// Returns the shape of weight `weightIndex` of the SpatWeightedCompute or Pim
/// core op that directly encloses `weightedOp`. Returns failure when there is
/// no such parent or when `weightIndex` is out of range. No diagnostics are
/// emitted.
llvm::FailureOr<llvm::ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
  FailureOr<ValueRange> weights = getEnclosingWeights(weightedOp);
  if (failed(weights) || weightIndex >= weights->size())
    return failure();
  return cast<ShapedType>((*weights)[weightIndex].getType()).getShape();
}

/// Verifies a weighted matrix-vector multiply. Two layouts are accepted:
///   1. matrix (N, M);       vector (M, 1);       output (N, 1)
///   2. matrix (N, M, 1, 1); vector (1, M, 1, 1); output (1, N, 1, 1)
LogicalResult SpatWeightedMVMOp::verify() {
  FailureOr<ArrayRef<int64_t>> matrixShapeOr =
      resolveWeightShape(getOperation(), getWeightIndex(), "SpatWeightedMVMOp");
  if (failed(matrixShapeOr))
    return failure();

  ArrayRef<int64_t> matrixShape = *matrixShapeOr;
  ArrayRef<int64_t> vectorShape = getInput().getType().getShape();
  ArrayRef<int64_t> outputShape = getOutput().getType().getShape();

  switch (matrixShape.size()) {
  case 2:
    return mvmOpVerifySize2(getOperation(), matrixShape, vectorShape, outputShape);
  case 4:
    return mvmOpVerifySize4(getOperation(), matrixShape, vectorShape, outputShape);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}

/// Verifies a weighted vector-matrix multiply. Accepted layout:
///   vector (1, N); matrix (N, M); output (1, M), with N > 0 and M > 0.
LogicalResult SpatWeightedVMMOp::verify() {
  FailureOr<ArrayRef<int64_t>> matrixShapeOr =
      resolveWeightShape(getOperation(), getWeightIndex(), "SpatWeightedVMMOp");
  if (failed(matrixShapeOr))
    return failure();

  ArrayRef<int64_t> matrixShape = *matrixShapeOr;
  ArrayRef<int64_t> vectorShape = getInput().getType().getShape();
  ArrayRef<int64_t> outputShape = getOutput().getType().getShape();

  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  const int64_t n = matrixShape[0];
  const int64_t m = matrixShape[1];
  if (n <= 0 || m <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Row vector: the messages describe the (1, X) layout actually checked.
  if (vectorShape[0] != 1 || vectorShape[1] != n)
    return emitError("vector shape must be (1, N)");

  if (outputShape[0] != 1 || outputShape[1] != m)
    return emitError("output shape must be (1, M)");

  return success();
}

/// Requires at least two operands, all sharing the result type.
LogicalResult SpatVAddOp::verify() { return verifyVariadicSameTypeOp(getOperation()); }

/// Requires at least two operands, all sharing the result type.
LogicalResult SpatVMaxOp::verify() { return verifyVariadicSameTypeOp(getOperation()); }

/// Verifies a weighted compute region:
///  - the body block ends with a yield op,
///  - the yield has exactly one operand per op result, with identical types
///    (type identity already implies identical shape and encoding),
///  - every block argument is used.
// NOTE(review): the yield op class name (SpatYieldOp) was reconstructed from a
// garbled source dump; confirm against SpatialOps.td.
LogicalResult SpatWeightedCompute::verify() {
  // The custom verifier runs before MLIR's structural block checks, so guard
  // against an empty region/block instead of asserting in front()/back().
  Region& body = getBody();
  if (body.empty())
    return emitError("ComputeOp must have a single yield operation");

  Block& block = body.front();
  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.empty() ? nullptr : &block.back());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");

  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size())
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");

  for (auto [resultType, yieldType] : llvm::zip(resultTypes, yieldTypes)) {
    if (resultType != yieldType)
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
  }

  if (llvm::any_of(block.getArguments(), [](BlockArgument arg) { return arg.use_empty(); }))
    return emitError("ComputeOp block argument is not used");

  return success();
}

} // namespace spatial
} // namespace onnx_mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"