Files
Raptor/src/PIM/Dialect/Spatial/SpatialOps.cpp
NiccoloN 6e1de865bb add constant folding and verification pass for pim host operations
better validation scripts output
big refactors
2026-03-20 12:08:12 +01:00

340 lines
12 KiB
C++

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
#include <cstdint>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
/// Registers the Spatial dialect's types and operations on this dialect.
///
/// The template argument lists are not spelled out here: each one is supplied
/// by including a `.cpp.inc` file (presumably TableGen-generated, like the
/// other `.inc` files included at the end of this file) right after defining
/// the macro that selects the "list" section of that file.
void SpatialDialect::initialize() {
// Types: GET_TYPEDEF_LIST selects the type list inside SpatialTypes.cpp.inc.
addTypes<
#define GET_TYPEDEF_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"
>();
// Operations: GET_OP_LIST selects the op list inside SpatialOps.cpp.inc.
addOperations<
#define GET_OP_LIST
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
>();
}
/// Checks the rank-2 form of a weighted matrix-vector multiplication:
///   matrix (N, M) x vector (M, 1) -> output (N, 1), with N > 0 and M > 0.
///
/// \param emitter      op used to report the diagnostic on failure.
/// \param matrixShape  shape of the weight matrix.
/// \param vectorShape  shape of the input vector.
/// \param outputShape  shape of the result.
/// \returns success() when all three shapes agree, otherwise the failure
///          produced by emitting an error on `emitter`.
inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  // Every shape must be two-dimensional before any element is inspected.
  const bool allRankTwo = matrixShape.size() == 2 && vectorShape.size() == 2 && outputShape.size() == 2;
  if (!allRankTwo)
    return emitter->emitError("matrix, vector and output must have rank 2");

  // Matrix is (N, M) with strictly positive extents.
  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  if (!(rows > 0 && cols > 0))
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // Vector is the (M, 1) column matching the matrix width.
  const bool vectorMatches = vectorShape[0] == cols && vectorShape[1] == 1;
  if (!vectorMatches)
    return emitter->emitError("vector shape must be (M, 1)");

  // Output is the (N, 1) column matching the matrix height.
  const bool outputMatches = outputShape[0] == rows && outputShape[1] == 1;
  if (!outputMatches)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}
/// Checks the rank-4 form of a weighted matrix-vector multiplication:
///   matrix (N, M, 1, 1) x vector (1, M, 1, 1) -> output (1, N, 1, 1),
///   with N > 0 and M > 0.
///
/// When `ignoreConcatError` is set, a vector whose length differs from M is
/// still accepted as long as its three other dimensions are 1 (the original
/// code attributes that mismatch to the simplification of the concat error).
///
/// \param emitter      op used to report the diagnostic on failure.
/// \param matrixShape  shape of the weight matrix.
/// \param vectorShape  shape of the input vector.
/// \param outputShape  shape of the result.
/// \returns success() when the shapes are accepted, otherwise the failure
///          produced by emitting an error on `emitter`.
inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  // Every shape must be four-dimensional before any element is inspected.
  const bool allRankFour = matrixShape.size() == 4 && vectorShape.size() == 4 && outputShape.size() == 4;
  if (!allRankFour)
    return emitter->emitError("matrix, vector and output must have rank 4");

  // Matrix is (N, M, 1, 1) with strictly positive N and M.
  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  const bool matrixTailIsUnit = matrixShape[2] == 1 && matrixShape[3] == 1;
  if (rows <= 0 || cols <= 0 || !matrixTailIsUnit)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  // Vector is (1, M, 1, 1). The unit dimensions are always mandatory; a length
  // mismatch is only an error when ignoreConcatError is not set.
  const bool vectorUnitDims = vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1;
  const bool vectorLengthMatches = vectorShape[1] == cols;
  if (!vectorUnitDims || (!vectorLengthMatches && !ignoreConcatError))
    return emitter->emitError("vector shape must be (1, M, 1, 1)");

  // Output is (1, N, 1, 1).
  const bool outputMatches =
      outputShape[0] == 1 && outputShape[1] == rows && outputShape[2] == 1 && outputShape[3] == 1;
  if (!outputMatches)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}
/// Looks up the shape of the weight that a weighted op refers to.
///
/// The weight list is taken from the op's direct parent, which must be either
/// a SpatWeightedCompute or a pim::PimCoreOp.
///
/// \param weigthedOp   the op whose parent owns the weights.
/// \param weightIndex  position of the weight in the parent's weight list.
/// \returns the shape of the selected weight, or failure() when
///          - the op has no parent, or the parent is neither of the two
///            supported ops,
///          - `weightIndex` is outside the parent's weight list, or
///          - the selected weight does not have a ShapedType.
///          No diagnostic is emitted here; callers decide how to report.
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
  // Shared lookup for both parent kinds. It never indexes out of range and
  // never casts blindly, so malformed IR yields failure() instead of tripping
  // an assertion inside a verifier.
  auto shapeOfWeight = [weightIndex](auto weights) -> llvm::FailureOr<ArrayRef<int64_t>> {
    if (weightIndex >= weights.size())
      return failure();
    auto shapedType = dyn_cast<ShapedType>(weights[weightIndex].getType());
    if (!shapedType)
      return failure();
    return shapedType.getShape();
  };

  // getParentOp() is null for an op that is not nested in another op, hence
  // the null-tolerant casts.
  Operation* parentOp = weigthedOp->getParentOp();
  if (auto wcomputeOp = dyn_cast_or_null<SpatWeightedCompute>(parentOp))
    return shapeOfWeight(wcomputeOp.getWeights());
  if (auto coreOp = dyn_cast_or_null<pim::PimCoreOp>(parentOp))
    return shapeOfWeight(coreOp.getWeights());
  return failure();
}
/// Verifies the shapes of a weighted matrix-vector multiplication.
///
/// The matrix shape is read from the weight selected by this op's weight
/// index in the enclosing op. Two layouts are accepted:
///   1. matrix (N, M);       vector (M, 1);       output (N, 1)
///   2. matrix (N, M, 1, 1); vector (1, M, 1, 1); output (1, N, 1, 1)
LogicalResult SpatWeightedMVMOp::verify() {
  auto weightShape = getWeightShapeForWeightedOp(getOperation(), getWeightIndex());
  if (failed(weightShape))
    return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");

  // Named lvalues are required: the helpers take the shapes by non-const
  // reference.
  ArrayRef<int64_t> matrixDims = *weightShape;
  ArrayRef<int64_t> vectorDims = getVector().getType().getShape();
  ArrayRef<int64_t> outputDims = getOutput().getType().getShape();

  // The rank of the matrix selects which layout is checked.
  switch (matrixDims.size()) {
  case 2:
    return mvmOpVerifySize2(this, matrixDims, vectorDims, outputDims);
  case 4:
    return mvmOpVerifySize4(this, matrixDims, vectorDims, outputDims);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}
/// Verifies the shapes of a weighted vector-matrix multiplication.
///
/// The matrix shape is read from the weight selected by this op's weight
/// index in the enclosing op. The only accepted layout is
///   vector (1, N); matrix (N, M); output (1, M), with N > 0 and M > 0.
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");

  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getVector().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
    return emitError("matrix, vector and output must have rank 2");

  const int64_t N = matrixShape[0];
  const int64_t M = matrixShape[1];
  if (N <= 0 || M <= 0)
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // The vector is a row (1, N); the diagnostic states the layout that is
  // actually enforced (it previously claimed "(N, 1)").
  if (vectorShape[0] != 1 || vectorShape[1] != N)
    return emitError("vector shape must be (1, N)");

  // The output is a row (1, M); same correction as above (it previously
  // claimed "(M, 1)").
  if (outputShape[0] != 1 || outputShape[1] != M)
    return emitError("output shape must be (1, M)");

  return success();
}
/// Verifies that the op has at least two operands and that all operands and
/// the result share the same type.
LogicalResult SpatVAddOp::verify() {
  Operation* op = getOperation();
  // The operand-count check runs first; the type check is only reached when
  // it succeeds.
  const bool hasEnoughOperands = succeeded(OpTrait::impl::verifyAtLeastNOperands(op, 2));
  return hasEnoughOperands ? OpTrait::impl::verifySameOperandsAndResultType(op) : failure();
}
/// Verifies that the op has at least two operands and that all operands and
/// the result share the same type.
LogicalResult SpatVMaxOp::verify() {
  Operation* op = getOperation();
  // The operand-count check runs first; the type check is only reached when
  // it succeeds.
  const bool hasEnoughOperands = succeeded(OpTrait::impl::verifyAtLeastNOperands(op, 2));
  return hasEnoughOperands ? OpTrait::impl::verifySameOperandsAndResultType(op) : failure();
}
/// Verifies that the operands of an image concat tile the result image.
///
/// Each output pixel is covered by `ceil(channels / crossbarSize)` operand
/// tiles. Every tile must have N, width and height equal to 1. Every tile must
/// carry exactly `crossbarSize` channels, except the last tile of a pixel,
/// which carries the remainder when the channel count is not a multiple of
/// `crossbarSize`.
LogicalResult SpatImgConcatOp::verify() {
  auto imgShape = mlir::cast<ShapedType>(getType());
  const size_t img_w = getImageWidth(imgShape);
  const size_t img_h = getImageHeight(imgShape);
  const size_t img_c = getImageChannel(imgShape);

  const size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());
  const size_t channelTileRest = img_c % crossbarSize;

  // The operand count must be checked first: getInputTile() indexes the
  // operand list assuming exactly this many entries.
  if (img_w * img_h * channelTiles != getOperands().size())
    return emitError("Number of operands does not match output image size");

  // For each output pixel, check that the input tiles have a correct shape.
  for (size_t x = 0; x < img_w; x++) {
    for (size_t y = 0; y < img_h; y++) {
      size_t channel_counts = 0;

      for (size_t t = 0; t < channelTiles; t++) {
        // dyn_cast (not cast) so that a non-shaped operand reaches the
        // diagnostic below instead of asserting inside the cast.
        auto inputShape = mlir::dyn_cast<ShapedType>(getInputTile(x, y, t).getType());
        if (!inputShape)
          return emitError("Invalid input type, must be ShapedType");

        // N == W == H == 1
        if (getImageN(inputShape) != 1 || getImageWidth(inputShape) != 1 || getImageHeight(inputShape) != 1)
          return emitError("Invalid input shape: N,W,H must all be 1");

        const size_t inputChannels = getImageChannel(inputShape);

        // Check the number of channels in this tile:
        // - last tile of the pixel with a non-zero remainder: must match it;
        // - otherwise: must be exactly crossbarSize.
        const bool isPartialLastTile = t == channelTiles - 1 && channelTileRest != 0;
        if (isPartialLastTile) {
          if (inputChannels != channelTileRest)
            return emitError("Invalid channel count for last tile of pixel");
        }
        else if (inputChannels != crossbarSize) {
          return emitError("Invalid channel count for some pixel tile");
        }

        channel_counts += inputChannels;
      }

      // Previously the error was emitted without returning, so verification
      // reported the problem and then still succeeded.
      if (channel_counts != img_c)
        return emitError("Invalid number of channels for some pixel");
    }
  }

  return success();
}
/// Verifies the body of a weighted compute op.
///
/// Checks that:
///  - the first block of the body ends with a SpatYieldOp,
///  - the yield has as many operands as the op has results,
///  - each yielded type is identical to the corresponding result type,
///  - every block argument of the body is used.
LogicalResult SpatWeightedCompute::verify() {
  // Guard the lookups explicitly: Region::front() on an empty region is
  // invalid, and Block::getTerminator() asserts when the block has no
  // terminator, so neither can be used to *detect* a malformed body.
  if (getBody().empty())
    return emitError("ComputeOp must have a single yield operation");
  Block& block = getBody().front();
  auto yieldOp = block.empty() ? SpatYieldOp() : dyn_cast<SpatYieldOp>(&block.back());
  if (!yieldOp)
    return emitError("ComputeOp must have a single yield operation");

  auto resultTypes = getResultTypes();
  auto yieldTypes = yieldOp->getOperandTypes();
  if (resultTypes.size() != yieldTypes.size()) {
    return emitError("ComputeOp must have same number of results as yieldOp "
                     "operands");
  }

  // Type identity is the whole check. Identical types trivially have
  // compatible shapes and the same encoding, so the extra shape and encoding
  // checks that used to follow the inequality test could never fire.
  for (auto [resultType, yieldType] : llvm::zip(resultTypes, yieldTypes)) {
    if (resultType != yieldType) {
      return emitError("ComputeOp output must be of the same type as yieldOp "
                       "operand");
    }
  }

  // Check that each block argument is used.
  for (auto arg : block.getArguments())
    if (arg.use_empty())
      return emitError("ComputeOp block argument is not used");

  return success();
}
/// Returns the operand holding channel tile `tile` of output pixel (x, y).
///
/// Operands are indexed with the tile as the fastest-varying coordinate, then
/// x, then y. The coordinates are only range-checked through assertions.
Value SpatImgConcatOp::getInputTile(size_t x, size_t y, size_t tile) {
  // Result image geometry determines the operand layout.
  auto imgShape = mlir::cast<ShapedType>(getType());
  size_t img_w = getImageWidth(imgShape);
  size_t img_h = getImageHeight(imgShape);
  size_t img_c = getImageChannel(imgShape);
  size_t channelTiles = ceilIntegerDivide(img_c, crossbarSize.getValue());

  assert(tile < channelTiles);
  assert(x < img_w);
  assert(y < img_h);

  // Linearized position: (y * img_w + x) pixels precede this one, each
  // contributing channelTiles operands.
  const size_t operandIndex = (y * img_w + x) * channelTiles + tile;
  return getOperands()[operandIndex];
}
} // namespace spatial
} // namespace onnx_mlir
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.cpp.inc"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialTypes.cpp.inc"