#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"

// NOTE(review): the original text contained a bare `#include` whose header
// name was lost (stripped together with the angle brackets). <cassert> is the
// most likely candidate since assert() is used below in
// SpatImgConcatOp::getInputTile — TODO confirm against the original file.
#include <cassert>

#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace spatial {

/// Registers the tablegen-generated types and operations with the Spatial
/// dialect.
void SpatialDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
      >();
}

/// Verifies the rank-2 weighted-MVM layout:
///   matrix: (N, M); vector: (M, 1); output: (N, 1).
/// \p emitter is used only to attach diagnostics to the op being verified.
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp *emitter,
    ArrayRef<int64_t> matrixShape, ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {
  // All three shapes must have rank 2.
  if (matrixShape.size() != 2 || vectorShape.size() != 2 ||
      outputShape.size() != 2)
    return emitter->emitError("matrix, vector and output must have rank 2");

  // Matrix must be (N, M) with static, strictly positive extents.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitter->emitError(
        "matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector must be (M, 1).
  if (vectorShape[0] != M || vectorShape[1] != 1)
    return emitter->emitError("vector shape must be (M, 1)");

  // Output must be (N, 1).
  if (outputShape[0] != N || outputShape[1] != 1)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}

/// Verifies the rank-4 weighted-MVM layout:
///   matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1).
/// When `ignoreConcatError` is set, a vector whose channel extent does not
/// match M is tolerated (all other vector dims must still be 1) — a known
/// artifact of the concat simplification.
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp *emitter,
    ArrayRef<int64_t> matrixShape, ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {
  // All three shapes must have rank 4.
  if (matrixShape.size() != 4 || vectorShape.size() != 4 ||
      outputShape.size() != 4)
    return emitter->emitError("matrix, vector and output must have rank 4");

  // Matrix must be (N, M, 1, 1) with static, strictly positive N and M.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0 || matrixShape[2] != 1 || matrixShape[3] != 1)
    return emitter->emitError(
        "matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  // Vector must be (1, M, 1, 1).
  if (vectorShape[0] != 1 || vectorShape[1] != M || vectorShape[2] != 1 ||
      vectorShape[3] != 1) {
    if (vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1 &&
        ignoreConcatError == true) {
      // Only the channel count mismatches: accepted, it was caused by the
      // simplification of the concat error.
    } else {
      return emitter->emitError("vector shape must be (1, M, 1, 1)");
    }
  }

  // Output must be (1, N, 1, 1).
  if (outputShape[0] != 1 || outputShape[1] != N || outputShape[2] != 1 ||
      outputShape[3] != 1)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}

/// Returns the shape of the weight tensor at \p weightIndex that feeds the
/// weighted op \p weightedOp, looked up on its parent op's `weights` operand
/// list. Fails if the parent is neither a SpatWeightedCompute nor a Core op.
///
/// NOTE(review): the template arguments of these casts were lost in the
/// original text; SpatWeightedCompute matches the op verified below, and
/// pim::CoreOp is inferred from the "Core op" wording of the callers'
/// diagnostics — TODO confirm both against the original file.
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(
    Operation *weightedOp, size_t weightIndex) {
  if (auto wcomputeOp =
          dyn_cast<SpatWeightedCompute>(weightedOp->getParentOp()))
    return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType())
        .getShape();
  if (auto coreOp = dyn_cast<pim::CoreOp>(weightedOp->getParentOp()))
    return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType())
        .getShape();
  return failure();
}

/// Verifies matrix/vector/output shape compatibility for the weighted
/// matrix-vector multiply. Two accepted layouts:
///   1. matrix: (N, M);       vector: (M, 1);       output: (N, 1)
///   2. matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1)
LogicalResult SpatWeightedMVMOp::verify() {
  auto matrixShapeOpt =
      getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError(
        "SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  if (matrixShape.size() == 2)
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  else if (matrixShape.size() == 4)
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  else
    return emitError("matrix rank must be 2 or 4");
}

/// Verifies shapes for the weighted vector-matrix multiply. Accepted layout:
///   vector: (1, N); matrix: (N, M); output: (1, M)
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt =
      getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError(
        "SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  if (matrixShape.size() != 2 || vectorShape.size() != 2 ||
      outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector must be (1, N).
  // Fix: the original diagnostic said "(N, 1)" while the check (and the
  // documented layout) enforce (1, N).
  if (vectorShape[1] != N || vectorShape[0] != 1)
    return emitError("vector shape must be (1, N)");

  // Output must be (1, M).
  // Fix: the original diagnostic said "(M, 1)" while the check enforces (1, M).
  if (outputShape[1] != M || outputShape[0] != 1)
    return emitError("output shape must be (1, M)");

  return success();
}

/// Element-wise vector add: variadic with at least two operands, all operands
/// and the result share one type.
LogicalResult SpatVAddOp::verify() {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return failure();
  return OpTrait::impl::verifySameOperandsAndResultType(*this);
}

/// Element-wise vector max: variadic with at least two operands, all operands
/// and the result share one type.
LogicalResult SpatVMaxOp::verify() {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return failure();
  return OpTrait::impl::verifySameOperandsAndResultType(*this);
}

/// Verifies that the concat's operands tile the output image exactly: one
/// operand per (x, y, channel-tile) triple, each of shape (1, C_t, 1, 1) where
/// C_t is `crossbarSize` for all tiles except possibly the last tile of each
/// pixel, which carries the remainder img_c % crossbarSize (when nonzero).
LogicalResult SpatImgConcatOp::verify() {
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = getImageWidth(imgShape);
  size_t img_h = getImageHeight(imgShape);
  size_t img_c = getImageChannel(imgShape);
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  size_t channelTileRest = img_c % crossbarSize;
  auto operands = getOperands();

  // Check number of operands.
  if (img_w * img_h * channelTiles != operands.size())
    return emitError("Number of operands does not match output image size");

  // For each output pixel, check that the inputTiles have a correct shape.
  for (size_t x = 0; x < img_w; x++) {
    for (size_t y = 0; y < img_h; y++) {
      size_t channel_counts = 0;
      for (size_t t = 0; t < channelTiles; t++) {
        // Fix: use dyn_cast so the "Invalid input type" diagnostic is
        // reachable (mlir::cast asserts on mismatch, making the null check
        // dead).
        auto inputShape =
            mlir::dyn_cast<ShapedType>(getInputTile(x, y, t).getType());
        if (!inputShape)
          return emitError("Invalid input type, must be ShapedType");
        // N == W == H == 1
        if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 ||
            getImageHeight(inputShape) != 1)
          return emitError("Invalid input shape: N,W,H must all be 1");
        size_t inputChannels = getImageChannel(inputShape);
        // Check the number of channels in this tile are correct:
        // - CASE1: last tile of pixel, if there is some rest it must match that
        // - CASE2: common case, the channel count is exactly the crossbarSize
        if (t == channelTiles - 1 && channelTileRest != 0) {
          if (inputChannels != channelTileRest)
            return emitError("Invalid channel count for last tile of pixel");
        } else {
          if (inputChannels != crossbarSize)
            return emitError("Invalid channel count for some pixel tile");
        }
        channel_counts += inputChannels;
      }
      // Fix: the original emitted this diagnostic without returning, so the
      // verifier reported the error yet still succeeded.
      if (channel_counts != img_c)
        return emitError("Invalid number of channels for some pixel");
    }
  }
  return success();
}

/// Verifies the compute region: the body must end in a yield whose operands
/// match the op's results one-to-one (same type, compatible shape, same
/// tensor encoding), and every block argument must be used.
LogicalResult SpatWeightedCompute::verify() {
  // Check that it has a terminator, it is a yieldOp, and its operands match
  // the results.
  auto &block = getBody().front();
  // NOTE(review): the terminator op type was lost in the original text;
  // SpatYieldOp is inferred from the dialect's Spat* naming convention —
  // TODO confirm against the original file.
  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");

  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size()) {
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");
  }
  for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
    auto resultType = std::get<0>(it);
    auto yieldType = std::get<1>(it);
    // Same type and compatible shape.
    if (resultType != yieldType ||
        failed(verifyCompatibleShape(resultType, yieldType))) {
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
    }
    // Same encoding (encodings live on RankedTensorType).
    if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
          return emitError("ComputeOp output must have the same encoding as "
                           "yieldOp operand");
        }
      } else {
        return emitError("ComputeOp output has an encoding while yieldOp "
                         "operand does not have one");
      }
    } else {
      // If result does not have an encoding, yield shouldn't either.
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        return emitError("ComputeOp output must not have an encoding if "
                         "yieldOp operand has one");
      }
    }
  }

  // Check that each block argument is used.
  for (auto arg : block.getArguments())
    if (arg.use_empty())
      return emitError("ComputeOp block argument is not used");

  return success();
}

/// Returns the operand holding channel-tile \p tile of pixel (\p x, \p y).
/// Operand layout: tile index varies fastest, then x, then y — matching the
/// operand-count check in SpatImgConcatOp::verify().
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
  auto operands = getOperands();
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = getImageWidth(imgShape);
  size_t img_h = getImageHeight(imgShape);
  size_t img_c = getImageChannel(imgShape);
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  assert(tile < channelTiles);
  assert(x < img_w);
  assert(y < img_h);
  return operands[tile + x * channelTiles + y * img_w * channelTiles];
}

} // namespace spatial
} // namespace onnx_mlir

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"

// NOTE(review): SpatialDialect.cpp.inc is usually included without a
// GET_TYPEDEF_CLASSES guard; kept as in the original — confirm the intended
// ordering of these includes.
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"