This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static LogicalResult verifyManyCommunicationSizes(Operation* op, ArrayRef<int32_t> coreIds, size_t valueCount) {
|
||||
if (coreIds.size() != valueCount)
|
||||
return op->emitError("core id metadata length must match the number of values");
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||
}
|
||||
|
||||
static LogicalResult verifyCompatibleShapedTypes(Operation* op, Type lhs, Type rhs, StringRef message) {
|
||||
auto lhsShaped = dyn_cast<ShapedType>(lhs);
|
||||
auto rhsShaped = dyn_cast<ShapedType>(rhs);
|
||||
if (!lhsShaped || !rhsShaped || !haveSameShapedContainerKind(lhs, rhs))
|
||||
return op->emitError(message);
|
||||
if (lhsShaped.getElementType() != rhsShaped.getElementType() || lhsShaped.getShape() != rhsShaped.getShape())
|
||||
return op->emitError(message);
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyCommunicationTypes(Operation* op, TypeRange types, StringRef kind) {
|
||||
if (types.empty())
|
||||
return op->emitError() << kind << " must carry at least one value";
|
||||
|
||||
Type firstType = types.front();
|
||||
auto firstShapedType = dyn_cast<ShapedType>(firstType);
|
||||
bool firstIsTensor = isa<RankedTensorType>(firstType);
|
||||
bool firstIsMemRef = isa<MemRefType>(firstType);
|
||||
for (Type type : types.drop_front())
|
||||
if (type != firstType) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
if (!firstShapedType || !shapedType)
|
||||
return op->emitError() << kind << " values must all have the same type";
|
||||
if (firstIsTensor != isa<RankedTensorType>(type) || firstIsMemRef != isa<MemRefType>(type))
|
||||
return op->emitError() << kind << " values must all use the same shaped container kind";
|
||||
if (firstShapedType.getElementType() != shapedType.getElementType() || firstShapedType.getShape() != shapedType.getShape())
|
||||
return op->emitError() << kind << " values must all have the same shape and element type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>();
|
||||
if (!coreBatchOp)
|
||||
return failure();
|
||||
return coreBatchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyManyBatchCommunicationSizes(Operation* op,
|
||||
ArrayRef<int32_t> coreIds,
|
||||
size_t valueCount) {
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside pim.core_batch");
|
||||
if (coreIds.size() != valueCount * static_cast<size_t>(*laneCount))
|
||||
return op->emitError("core id metadata length must match the number of values times parent laneCount");
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimEmptyManyOp::verify() {
|
||||
if (getOutputs().empty())
|
||||
return emitError("must produce at least one output");
|
||||
|
||||
Type firstType = getOutputs().front().getType();
|
||||
auto firstTensorType = dyn_cast<RankedTensorType>(firstType);
|
||||
if (!firstTensorType)
|
||||
return emitError("outputs must all be ranked tensor types");
|
||||
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != firstType)
|
||||
return emitError("outputs must all have the same type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimMapOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
if (getOutputs().size() != getInputs().size())
|
||||
return emitError("number of outputs must match number of inputs");
|
||||
|
||||
Type inputType = getInputs().front().getType();
|
||||
for (Value input : getInputs().drop_front())
|
||||
if (input.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
|
||||
Type outputType = getOutputs().front().getType();
|
||||
for (Value output : getOutputs().drop_front())
|
||||
if (output.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != inputType)
|
||||
return emitError("body block argument type must match input type");
|
||||
|
||||
auto yieldOp = dyn_cast_or_null<PimYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return emitError("body must terminate with pim.yield");
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputType)
|
||||
return emitError("body yield type must match output type");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyOp::verify() {
|
||||
if (failed(verifyManyCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many");
|
||||
}
|
||||
|
||||
LogicalResult PimSendManyBatchOp::verify() {
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getTargetCoreIds(), getInputs().size())))
|
||||
return failure();
|
||||
return verifyManyCommunicationTypes(getOperation(), getInputs().getTypes(), "send_many_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
if (failed(verifyManyCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many")))
|
||||
return failure();
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many")))
|
||||
return failure();
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
|
||||
if (outputBuffer.getType() != output.getType())
|
||||
return emitError("output buffers and outputs must have matching types");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimReceiveManyBatchOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
if (failed(verifyManyBatchCommunicationSizes(getOperation(), getSourceCoreIds(), getOutputs().size())))
|
||||
return failure();
|
||||
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOutputBuffers().getTypes(), "receive_many_batch")))
|
||||
return failure();
|
||||
if (failed(verifyManyCommunicationTypes(getOperation(), getOperation()->getResultTypes(), "receive_many_batch")))
|
||||
return failure();
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs()))
|
||||
if (outputBuffer.getType() != output.getType())
|
||||
return emitError("output buffers and outputs must have matching types");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimExtractRowsOp::verify() {
|
||||
if (getOutputBuffers().size() != getOutputs().size())
|
||||
return emitError("number of output buffers must match the number of outputs");
|
||||
|
||||
auto inputType = dyn_cast<ShapedType>(getInput().getType());
|
||||
if (!inputType || !inputType.hasRank() || inputType.getRank() != 2)
|
||||
return emitError("input must be a rank-2 shaped type");
|
||||
|
||||
int64_t numRows = inputType.getShape()[0];
|
||||
int64_t numCols = inputType.getShape()[1];
|
||||
Type elementType = inputType.getElementType();
|
||||
|
||||
if (numRows >= 0 && static_cast<int64_t>(getOutputs().size()) != numRows)
|
||||
return emitError("number of outputs must match the number of input rows");
|
||||
|
||||
for (auto [outputBuffer, output] : llvm::zip(getOutputBuffers(), getOutputs())) {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), outputBuffer.getType(), output.getType(), "output buffers and outputs must match")))
|
||||
return failure();
|
||||
|
||||
auto outputType = dyn_cast<ShapedType>(output.getType());
|
||||
if (!outputType || !outputType.hasRank() || outputType.getRank() != 2)
|
||||
return emitError("outputs must all be rank-2 shaped types");
|
||||
if (!haveSameShapedContainerKind(getInput().getType(), output.getType()))
|
||||
return emitError("outputs must use the same shaped container kind as the input");
|
||||
if (outputType.getElementType() != elementType)
|
||||
return emitError("output element types must match input element type");
|
||||
auto outputShape = outputType.getShape();
|
||||
if (outputShape[0] != 1)
|
||||
return emitError("each output must have exactly one row");
|
||||
if (numCols >= 0 && outputShape[1] != numCols)
|
||||
return emitError("output column count must match input column count");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!outputType || !outputType.hasRank())
|
||||
return emitError("output must be a ranked shaped type");
|
||||
|
||||
int64_t axis = getAxis();
|
||||
int64_t rank = outputType.getRank();
|
||||
if (axis < 0 || axis >= rank)
|
||||
return emitError("axis must be within the output rank");
|
||||
|
||||
int64_t concatenatedDimSize = 0;
|
||||
bool concatenatedDimDynamic = false;
|
||||
Type outputElementType = outputType.getElementType();
|
||||
|
||||
for (Value input : getInputs()) {
|
||||
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||
if (!inputType || !inputType.hasRank())
|
||||
return emitError("inputs must be ranked shaped types");
|
||||
if (!haveSameShapedContainerKind(input.getType(), getOutput().getType()))
|
||||
return emitError("inputs and output must use the same shaped container kind");
|
||||
if (inputType.getRank() != rank)
|
||||
return emitError("all inputs must have the same rank as the output");
|
||||
if (inputType.getElementType() != outputElementType)
|
||||
return emitError("all inputs must have the same element type as the output");
|
||||
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
if (dim == axis)
|
||||
continue;
|
||||
int64_t inputDim = inputType.getDimSize(dim);
|
||||
int64_t outputDim = outputType.getDimSize(dim);
|
||||
if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim)
|
||||
return emitError("non-concatenated dimensions must match the output shape");
|
||||
}
|
||||
|
||||
int64_t inputConcatDim = inputType.getDimSize(axis);
|
||||
if (ShapedType::isDynamic(inputConcatDim)) {
|
||||
concatenatedDimDynamic = true;
|
||||
continue;
|
||||
}
|
||||
concatenatedDimSize += inputConcatDim;
|
||||
}
|
||||
|
||||
int64_t outputConcatDim = outputType.getDimSize(axis);
|
||||
if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim)
|
||||
return emitError("output concatenated dimension must equal the sum of input sizes");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user