//===----------------------------------------------------------------------===//
// Spatial dialect implementation: dialect registration and op verifiers.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Shape/IR/Shape.h"
|
|
#include "mlir/Dialect/Traits.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/IntegerSet.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include <cstdint>
|
|
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace spatial {
|
|
|
|
// Register all types and operations of the Spatial dialect. The concrete
// lists are generated by TableGen and pulled in via the .cpp.inc includes.
void SpatialDialect::initialize() {
  // TableGen-generated list of Spatial types.
  addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
      >();
  // TableGen-generated list of Spatial operations.
  addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
      >();
}
|
|
|
|
/// Verify the rank-2 shape contract of a SpatWeightedMVMOp:
///   matrix: (N, M); vector: (M, 1); output: (N, 1), with N, M > 0.
/// Diagnostics are emitted through `emitter`. ArrayRef is a cheap
/// non-owning view, so it is taken by value per LLVM convention.
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
    ArrayRef<int64_t> matrixShape,
    ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {

  // All three shapes must have rank 2.
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitter->emitError("matrix, vector and output must have rank 2");

  // Matrix must be (N, M) with static, positive dimensions.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector must be a column vector (M, 1) matching the matrix's M.
  int64_t vectorM = vectorShape[0];
  int64_t vector1 = vectorShape[1];
  if (vectorM != M || vector1 != 1)
    return emitter->emitError("vector shape must be (M, 1)");

  // Output must be a column vector (N, 1) matching the matrix's N.
  int64_t outputN = outputShape[0];
  int64_t output1 = outputShape[1];
  if (outputN != N || output1 != 1)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}
|
|
|
|
/// Verify the rank-4 shape contract of a SpatWeightedMVMOp:
///   matrix: (N, M, 1, 1); vector: (1, M, 1, 1); output: (1, N, 1, 1),
/// with N, M > 0. When `ignoreConcatError` is set, a vector whose M
/// dimension mismatches is tolerated as long as its other three dims are
/// all 1 (workaround for the concat simplification). Diagnostics go
/// through `emitter`. ArrayRef is a cheap non-owning view, so it is
/// taken by value per LLVM convention.
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
    ArrayRef<int64_t> matrixShape,
    ArrayRef<int64_t> vectorShape,
    ArrayRef<int64_t> outputShape) {

  // All three shapes must have rank 4.
  if (matrixShape.size() != 4 || vectorShape.size() != 4 || outputShape.size() != 4)
    return emitter->emitError("matrix, vector and output must have rank 4");

  // Matrix must be (N, M, 1, 1) with static, positive N and M.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  int64_t matrix1First = matrixShape[2];
  int64_t matrix1Second = matrixShape[3];
  if (N <= 0 || M <= 0 || matrix1First != 1 || matrix1Second != 1)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  // Vector must be (1, M, 1, 1).
  int64_t vector1First = vectorShape[0];
  int64_t vectorM = vectorShape[1];
  int64_t vector1Second = vectorShape[2];
  int64_t vector1Third = vectorShape[3];
  const bool vectorOnesOk =
      vector1First == 1 && vector1Second == 1 && vector1Third == 1;
  if (!(vectorOnesOk && vectorM == M)) {
    // Only an M mismatch (all unit dims still 1) is forgivable, and only
    // when the concat-error workaround is enabled.
    if (!(vectorOnesOk && ignoreConcatError))
      return emitter->emitError("vector shape must be (1, M, 1, 1)");
  }

  // Output must be (1, N, 1, 1).
  int64_t output1First = outputShape[0];
  int64_t outputN = outputShape[1];
  int64_t output1Second = outputShape[2];
  int64_t output1Third = outputShape[3];
  if (output1First != 1 || outputN != N || output1Second != 1 || output1Third != 1)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}
|
|
|
|
/// Return the shape of weight `weightIndex` owned by the parent of
/// `weightedOp`. The parent must be either a SpatWeightedCompute or a
/// pim::PimCoreOp; otherwise failure() is returned so the caller can
/// emit a context-appropriate diagnostic.
/// (Parameter renamed from the misspelled "weigthedOp".)
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
  if (auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weightedOp->getParentOp()))
    return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();

  if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
    return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();

  return failure();
}
|
|
|
|
// Verify a SpatWeightedMVMOp. The weight matrix lives on the enclosing
// SpatWeightedCompute/Core op; two shape layouts are accepted:
//   rank 2: matrix (N, M); vector (M, 1); output (N, 1)
//   rank 4: matrix (N, M, 1, 1); vector (1, M, 1, 1); output (1, N, 1, 1)
LogicalResult SpatWeightedMVMOp::verify() {
  auto maybeMatrixShape =
      getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(maybeMatrixShape))
    return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");

  auto matrixShape = *maybeMatrixShape;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  // Dispatch on the matrix rank; anything but 2 or 4 is rejected.
  switch (matrixShape.size()) {
  case 2:
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  case 4:
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}
|
|
|
|
/// Verify a SpatWeightedVMMOp (vector-matrix multiply). The weight matrix
/// lives on the enclosing SpatWeightedCompute/Core op. Accepted shapes:
///   vector: (1, N); matrix: (N, M); output: (1, M), with N, M > 0.
/// Fix: the vector/output error messages previously said "(N, 1)" and
/// "(M, 1)", contradicting both the documented contract and the checks.
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  // All three shapes must have rank 2.
  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  // Matrix must be (N, M) with static, positive dimensions.
  int64_t N = matrixShape[0];
  int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector must be a row vector (1, N) matching the matrix's N.
  int64_t vector1 = vectorShape[0];
  int64_t vectorN = vectorShape[1];
  if (vectorN != N || vector1 != 1)
    return emitError("vector shape must be (1, N)");

  // Output must be a row vector (1, M) matching the matrix's M.
  int64_t output1 = outputShape[0];
  int64_t outputM = outputShape[1];
  if (outputM != M || output1 != 1)
    return emitError("output shape must be (1, M)");

  return success();
}
|
|
|
|
// Element-wise vector addition: requires at least two operands, and all
// operands plus the result must share one type.
LogicalResult SpatVAddOp::verify() {
  if (succeeded(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
|
|
|
// Element-wise vector maximum: requires at least two operands, and all
// operands plus the result must share one type.
LogicalResult SpatVMaxOp::verify() {
  if (succeeded(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
|
|
|
/// Verify a SpatImgConcatOp: operands are per-pixel channel tiles (each of
/// shape N=W=H=1) that concatenate into a W x H image with `img_c`
/// channels, tiled along the channel axis in chunks of `crossbarSize`.
/// Fixes: (1) the per-pixel channel-count diagnostic was emitted without
/// `return`, so a bad op still verified successfully; (2) the input type
/// null check was dead because `mlir::cast` asserts on failure — use
/// `dyn_cast` so the check is meaningful.
LogicalResult SpatImgConcatOp::verify() {
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = getImageWidth(imgShape);
  size_t img_h = getImageHeight(imgShape);
  size_t img_c = getImageChannel(imgShape);

  // Number of channel tiles per pixel, and the channel count of the last
  // (possibly partial) tile.
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  size_t channelTileRest = img_c % crossbarSize;

  auto operands = getOperands();

  // One operand per (pixel, channel tile).
  if (img_w * img_h * channelTiles != operands.size())
    return emitError("Number of operands does not match output image size");

  // For each output pixel, check that the inputTiles have a correct shape.
  for (size_t x = 0; x < img_w; x++) {
    for (size_t y = 0; y < img_h; y++) {
      size_t channel_counts = 0;
      for (size_t t = 0; t < channelTiles; t++) {
        // dyn_cast (not cast): a non-shaped input must produce a
        // diagnostic, not an assertion failure.
        auto inputShape = mlir::dyn_cast<ShapedType>(getInputTile(x, y, t).getType());
        if (!inputShape)
          return emitError("Invalid input type, must be ShapedType");

        // Each tile covers exactly one pixel: N == W == H == 1.
        if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
          return emitError("Invalid input shape: N,W,H must all be 1");

        size_t inputChannels = getImageChannel(inputShape);

        // Check the number of channels in this tile are correct:
        // - CASE1: last tile of pixel, if there is some rest it must match that
        // - CASE2: common case, the channel count is exactly the crossbarSize
        if (t == channelTiles - 1 && channelTileRest != 0) {
          if (inputChannels != channelTileRest)
            return emitError("Invalid channel count for last tile of pixel");
        } else {
          if (inputChannels != crossbarSize)
            return emitError("Invalid channel count for some pixel tile");
        }

        channel_counts += inputChannels;
      }

      // The tiles of a pixel must add up to the full channel count.
      // (Previously the diagnostic was emitted without `return`.)
      if (channel_counts != img_c)
        return emitError("Invalid number of channels for some pixel");
    }
  }

  return success();
}
|
|
|
|
/// Verify a SpatWeightedCompute region op:
/// - the body block terminates in a SpatYieldOp;
/// - yield operand types match the op's result types pairwise (type,
///   compatible shape, and tensor encoding);
/// - every block argument is used at least once.
/// Fix: the final branch bound an unused `yieldRankedType` variable
/// (unused-variable warning); `isa<>` expresses the membership test.
LogicalResult SpatWeightedCompute::verify() {
  // Check that it has a terminator, it is a yieldOp, and its operands
  // match the op results.
  auto& block = getBody().front();
  auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");

  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size()) {
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");
  }

  for (auto it : llvm::reverse(llvm::zip(resultTypes, yieldTypes))) {
    auto resultType = std::get<0>(it);
    auto yieldType = std::get<1>(it);

    // Same type and compatible shape.
    if (resultType != yieldType || failed(verifyCompatibleShape(resultType, yieldType))) {
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
    }

    // Same encoding: either both sides are ranked tensors with equal
    // encodings, or neither side is a ranked tensor.
    if (auto resultRankedType = dyn_cast<RankedTensorType>(resultType)) {
      if (auto yieldRankedType = dyn_cast<RankedTensorType>(yieldType)) {
        if (resultRankedType.getEncoding() != yieldRankedType.getEncoding()) {
          return emitError("ComputeOp output must have the same encoding as "
                           "yieldOp operand");
        }
      } else {
        return emitError("ComputeOp output has an encoding while yieldOp "
                         "operand does not have one");
      }
    } else {
      // If result is not a ranked tensor, yield shouldn't be one either.
      if (isa<RankedTensorType>(yieldType)) {
        return emitError("ComputeOp output must not have an encoding if "
                         "yieldOp operand has one");
      }
    }
  }

  // Check that each block argument is used.
  for (auto arg : block.getArguments())
    if (arg.use_empty())
      return emitError("ComputeOp block argument is not used");

  return success();
}
|
|
|
|
// Return the operand holding channel tile `tile` of output pixel (x, y).
// Operands are laid out tile-fastest within a pixel, pixels row-major
// (x varies before y).
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
  auto resultType = mlir::cast<ShapedType>(getType());
  const size_t width = getImageWidth(resultType);
  const size_t height = getImageHeight(resultType);
  const size_t channels = getImageChannel(resultType);

  // Number of channel tiles each pixel is split into.
  const size_t tilesPerPixel = ceilIntegerDivide(channels, crossbarSize.getValue());

  assert(tile < tilesPerPixel);
  assert(x < width);
  assert(y < height);

  const size_t index = (y * width + x) * tilesPerPixel + tile;
  return getOperands()[index];
}
|
|
|
|
} // namespace spatial
|
|
} // namespace onnx_mlir
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
|
|
|
|
#define GET_TYPEDEF_CLASSES
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
|