b79ce8eeaa
Validate Operations / validate-operations (push) Has been cancelled
run dce at the end of MaterializeMergeSchedule to get rid of unused constants
684 lines
28 KiB
C++
684 lines
28 KiB
C++
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/Diagnostics.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
|
|
#include "llvm/ADT/DenseSet.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace spatial {
|
|
|
|
namespace {
|
|
|
|
/// Verifies the rank-2 form of a matrix-vector multiply: an (N, M) matrix, an
/// (M, 1) vector and an (N, 1) output.
///
/// On the first violated constraint an error is emitted through `emitter` and
/// failure is returned; otherwise returns success.
inline LogicalResult mvmOpVerifySize2(SpatMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  const bool allRankTwo = matrixShape.size() == 2 && vectorShape.size() == 2 && outputShape.size() == 2;
  if (!allRankTwo)
    return emitter->emitError("matrix, vector and output must have rank 2");

  const int64_t rows = matrixShape.front();
  const int64_t cols = matrixShape.back();
  if (!(rows > 0 && cols > 0))
    return emitter->emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // The vector is a single column whose length equals the matrix column count.
  const bool vectorOk = vectorShape.front() == cols && vectorShape.back() == 1;
  if (!vectorOk)
    return emitter->emitError("vector shape must be (M, 1)");

  // The output is a single column whose length equals the matrix row count.
  const bool outputOk = outputShape.front() == rows && outputShape.back() == 1;
  if (!outputOk)
    return emitter->emitError("output shape must be (N, 1)");

  return success();
}
|
|
|
|
/// Verifies the rank-4 form of a matrix-vector multiply: an (N, M, 1, 1)
/// matrix, a (1, M, 1, 1) vector and a (1, N, 1, 1) output.
///
/// When `ignoreConcatError` is set, a vector whose unit dimensions are correct
/// but whose length differs from M is accepted. On the first violated
/// constraint an error is emitted through `emitter` and failure is returned.
inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
                                      ArrayRef<int64_t>& matrixShape,
                                      ArrayRef<int64_t>& vectorShape,
                                      ArrayRef<int64_t>& outputShape) {
  const bool allRankFour = matrixShape.size() == 4 && vectorShape.size() == 4 && outputShape.size() == 4;
  if (!allRankFour)
    return emitter->emitError("matrix, vector and output must have rank 4");

  const int64_t rows = matrixShape[0];
  const int64_t cols = matrixShape[1];
  const bool matrixTrailingUnit = matrixShape[2] == 1 && matrixShape[3] == 1;
  if (rows <= 0 || cols <= 0 || !matrixTrailingUnit)
    return emitter->emitError("matrix shape must be (N, M, 1, 1) with N > 0 and M > 0");

  const bool vectorUnitDims = vectorShape[0] == 1 && vectorShape[2] == 1 && vectorShape[3] == 1;
  const bool vectorLengthMatches = vectorShape[1] == cols;
  // A pure length mismatch is tolerated when ignoreConcatError is set: it was
  // caused by the simplification of the concat error. Wrong unit dimensions
  // are always rejected.
  if (!vectorUnitDims || (!vectorLengthMatches && !(ignoreConcatError == true)))
    return emitter->emitError("vector shape must be (1, M, 1, 1)");

  const bool outputOk =
      outputShape[0] == 1 && outputShape[1] == rows && outputShape[2] == 1 && outputShape[3] == 1;
  if (!outputOk)
    return emitter->emitError("output shape must be (1, N, 1, 1)");

  return success();
}
|
|
|
|
/// Returns the shape of `weight` when its type is a ShapedType, and failure
/// otherwise.
// NOTE(review): getShape() is called without a hasRank() check — confirm that
// weights reaching this helper are always ranked.
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
  if (auto shaped = dyn_cast<ShapedType>(weight.getType()))
    return shaped.getShape();
  return failure();
}
|
|
|
|
/// Returns the lane count of the SpatComputeBatch enclosing `op`, or failure
/// when `op` is not nested inside one.
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
  if (auto enclosingBatch = op->getParentOfType<SpatComputeBatch>())
    return enclosingBatch.getLaneCount();
  return failure();
}
|
|
|
|
/// Returns true when `value` is one of the output block arguments of
/// `batchOp`'s body entry block.
///
/// The output arguments are taken to be the `getNumResults()` consecutive
/// arguments starting at `getOutputArgument(0)`.
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
  const unsigned resultCount = batchOp.getNumResults();
  if (resultCount == 0)
    return false;

  auto candidate = dyn_cast<BlockArgument>(value);
  if (!candidate)
    return false;
  // Only arguments of this batch's own entry block qualify.
  if (candidate.getOwner() != &batchOp.getBody().front())
    return false;

  auto outputsBegin = batchOp.getOutputArgument(0);
  if (!outputsBegin)
    return false;

  const unsigned firstOutput = outputsBegin->getArgNumber();
  const unsigned position = candidate.getArgNumber();
  return firstOutput <= position && position < firstOutput + resultCount;
}
|
|
|
|
/// Returns true when `value` matches the m_ConstantInt pattern, i.e. it is
/// produced by a constant integer.
static bool isConstantIndexLike(Value value) {
  // The matcher needs somewhere to bind the constant; the value is not used.
  APInt boundConstant;
  const bool matched = matchPattern(value, m_ConstantInt(&boundConstant));
  return matched;
}
|
|
|
|
/// Recursively decides whether an affine expression is accepted as a lane
/// offset computation.
///
/// Accepted: constants, dimensions, sums of accepted expressions, products in
/// which at least one side is a constant (and the other side is accepted), and
/// floordiv/ceildiv/mod of an accepted expression by a constant. Symbols are
/// rejected.
static bool isSupportedLaneAffineExpr(AffineExpr expr) {
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId:
    return false;
  case AffineExprKind::Constant:
  case AffineExprKind::DimId:
    return true;
  case AffineExprKind::Add: {
    auto sum = cast<AffineBinaryOpExpr>(expr);
    if (!isSupportedLaneAffineExpr(sum.getLHS()))
      return false;
    return isSupportedLaneAffineExpr(sum.getRHS());
  }
  case AffineExprKind::Mul: {
    auto product = cast<AffineBinaryOpExpr>(expr);
    AffineExpr left = product.getLHS();
    AffineExpr right = product.getRHS();
    // Constant on the left, accepted expression on the right...
    if (isa<AffineConstantExpr>(left) && isSupportedLaneAffineExpr(right))
      return true;
    // ...or the mirrored arrangement.
    return isa<AffineConstantExpr>(right) && isSupportedLaneAffineExpr(left);
  }
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv:
  case AffineExprKind::Mod: {
    auto division = cast<AffineBinaryOpExpr>(expr);
    // The divisor/modulus must be a constant.
    if (!isa<AffineConstantExpr>(division.getRHS()))
      return false;
    return isSupportedLaneAffineExpr(division.getLHS());
  }
  }
  llvm_unreachable("unexpected affine expression kind");
}
|
|
|
|
/// Decides whether `value` is an accepted way of computing an offset from the
/// lane block argument `laneArg`.
///
/// Accepted forms:
///  - `laneArg` itself or a constant integer;
///  - an affine.apply with one result and no symbols, whose operands are all
///    accepted and whose expression passes isSupportedLaneAffineExpr;
///  - a tensor.extract with a single accepted index, reading from an
///    arith.constant holding a rank-1 DenseIntElementsAttr;
///  - an arith.addi of `laneArg` and a constant integer, in either order.
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
  if (value == laneArg)
    return true;
  if (isConstantIndexLike(value))
    return true;

  if (auto applyOp = value.getDefiningOp<affine::AffineApplyOp>()) {
    auto map = applyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumSymbols() != 0)
      return false;
    for (Value operand : applyOp.getMapOperands())
      if (!isSupportedLaneOffsetExpr(operand, laneArg))
        return false;
    return isSupportedLaneAffineExpr(map.getResult(0));
  }

  if (auto extractOp = value.getDefiningOp<tensor::ExtractOp>()) {
    // The table being indexed must be a rank-1 constant integer tensor.
    auto tableOp = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
    if (!tableOp)
      return false;
    auto table = dyn_cast<DenseIntElementsAttr>(tableOp.getValue());
    if (!table || table.getType().getRank() != 1)
      return false;
    auto indices = extractOp.getIndices();
    if (indices.size() != 1)
      return false;
    return isSupportedLaneOffsetExpr(indices.front(), laneArg);
  }

  auto addOp = value.getDefiningOp<arith::AddIOp>();
  if (!addOp)
    return false;
  Value left = addOp.getLhs();
  Value right = addOp.getRhs();
  if (left == laneArg && isConstantIndexLike(right))
    return true;
  return right == laneArg && isConstantIndexLike(left);
}
|
|
|
|
/// Verifies a tensor.extract_slice used inside a compute_batch body.
///
/// Requires statically shaped ranked source and result types, unit strides,
/// static slice sizes, and offsets where the first entry of getOffsets() is an
/// accepted lane expression and every later entry is a constant integer.
/// `kind` is prepended to each diagnostic.
// NOTE(review): getOffsets() is iterated directly. If it yields only the
// dynamic offset operands (as the upstream MLIR accessor does), `position`
// counts dynamic offsets rather than tensor dimensions — confirm this is the
// intended meaning of "first offset".
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
  auto srcType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
  auto dstType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
  const bool staticRanked = srcType && dstType && srcType.hasStaticShape() && dstType.hasStaticShape();
  if (!staticRanked)
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
  if (!sliceOp.hasUnitStride())
    return sliceOp.emitOpError() << kind << " requires unit strides";

  if (llvm::any_of(sliceOp.getStaticSizes(), [](int64_t extent) { return ShapedType::isDynamic(extent); }))
    return sliceOp.emitOpError() << kind << " requires static slice sizes";

  size_t position = 0;
  for (Value offset : sliceOp.getOffsets()) {
    const bool isFirst = position == 0;
    ++position;
    const bool accepted = isFirst ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
    if (!accepted)
      return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
  }

  return success();
}
|
|
|
|
/// Verifies a tensor.parallel_insert_slice used inside spat.in_parallel.
///
/// Requires statically shaped source and destination types, unit strides,
/// static slice sizes, and offsets where the first entry of getOffsets() is an
/// accepted lane expression and every later entry is a constant integer.
/// `kind` is prepended to each diagnostic.
// NOTE(review): getOffsets() is iterated directly. If it yields only the
// dynamic offset operands (as the upstream MLIR accessor does), `position`
// counts dynamic offsets rather than tensor dimensions — confirm this is the
// intended meaning of "first offset".
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
                                                                 BlockArgument laneArg,
                                                                 StringRef kind) {
  RankedTensorType srcType = sliceOp.getSourceType();
  RankedTensorType dstType = sliceOp.getDestType();
  const bool bothStatic = srcType.hasStaticShape() && dstType.hasStaticShape();
  if (!bothStatic)
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
  if (!sliceOp.hasUnitStride())
    return sliceOp.emitOpError() << kind << " requires unit strides";

  if (llvm::any_of(sliceOp.getStaticSizes(), [](int64_t extent) { return ShapedType::isDynamic(extent); }))
    return sliceOp.emitOpError() << kind << " requires static slice sizes";

  size_t position = 0;
  for (Value offset : sliceOp.getOffsets()) {
    const bool isFirst = position == 0;
    ++position;
    const bool accepted = isFirst ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
    if (!accepted)
      return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
  }

  return success();
}
|
|
|
|
/// Shared verifier for the non-batched tensor channel send/receive ops.
///
/// Checks that the three metadata arrays have equal, non-zero length, that
/// `type` is a statically shaped tensor with byte-sized elements, and that its
/// total byte size divides evenly by the number of channel ids. `kind` is
/// prepended to most diagnostics; errors are emitted on `op`.
static LogicalResult verifyTensorChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  const bool lengthsAgree = channelCount == sourceCoreCount && channelCount == targetCoreCount;
  if (!lengthsAgree)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
  if (channelCount == 0)
    return op->emitError() << kind << " must carry at least one chunk";

  auto tensorType = dyn_cast<ShapedType>(type);
  if (!(tensorType && tensorType.hasStaticShape()))
    return op->emitError() << kind << " requires a static shaped tensor";

  const int64_t bitsPerElement = tensorType.getElementTypeBitWidth();
  const bool wholeBytes = bitsPerElement > 0 && bitsPerElement % 8 == 0;
  if (!wholeBytes)
    return op->emitError() << kind << " requires byte-sized elements";

  // Every channel carries one equally sized chunk of the tensor.
  const int64_t byteSize = tensorType.getNumElements() * bitsPerElement / 8;
  const auto chunkCount = static_cast<int64_t>(channelCount);
  if (byteSize % chunkCount != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
  return success();
}
|
|
|
|
/// Shared verifier for the batched (non-tensor) channel send/receive ops.
///
/// Checks that the three metadata arrays have equal length, that `op` is
/// nested in a spat.compute_batch, and that the metadata length equals that
/// batch's lane count. Errors are emitted on `op`.
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
  const bool lengthsAgree = channelCount == sourceCoreCount && channelCount == targetCoreCount;
  if (!lengthsAgree)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  auto parentLanes = getParentBatchLaneCount(op);
  if (failed(parentLanes))
    return op->emitError("must be nested inside spat.compute_batch");

  const auto expectedLength = static_cast<size_t>(*parentLanes);
  if (expectedLength != channelCount)
    return op->emitError("channel metadata length must match parent laneCount");

  return success();
}
|
|
|
|
/// Shared verifier for the batched tensor channel send/receive ops.
///
/// Checks that the three metadata arrays have equal length, that `op` is
/// nested in a spat.compute_batch, that the metadata length is a positive
/// multiple of that batch's lane count, that `type` is a statically shaped
/// tensor with byte-sized elements, and that its total byte size divides
/// evenly by the number of chunks per lane. `kind` is prepended to most
/// diagnostics; errors are emitted on `op`.
static LogicalResult verifyTensorBatchChannelSizes(
    Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  auto laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");

  // A non-positive lane count has no positive multiples. Rejecting it here
  // also keeps the modulo and division below from dividing by zero when this
  // op is verified while the parent's laneCount attribute is invalid.
  const int64_t lanes = *laneCount;
  if (lanes <= 0 || channelCount == 0 || channelCount % static_cast<size_t>(lanes) != 0)
    return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";

  auto shapedType = dyn_cast<ShapedType>(type);
  if (!shapedType || !shapedType.hasStaticShape())
    return op->emitError() << kind << " requires a static shaped tensor";

  int64_t elementBits = shapedType.getElementTypeBitWidth();
  if (elementBits <= 0 || elementBits % 8 != 0)
    return op->emitError() << kind << " requires byte-sized elements";

  // chunksPerLane >= 1 because channelCount is a positive multiple of lanes.
  const int64_t chunksPerLane = static_cast<int64_t>(channelCount) / lanes;
  // elementBits is a multiple of 8, so dividing first is exact and keeps the
  // intermediate product smaller.
  const int64_t totalBytes = shapedType.getNumElements() * (elementBits / 8);
  if (totalBytes % chunksPerLane != 0)
    return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";

  return success();
}
|
|
|
|
/// Returns the region that directly contains the definition of `value`: the
/// parent of the owning block for a block argument, or the parent region of
/// the defining operation otherwise. Returns nullptr when neither applies.
static Region* getParentRegion(Value value) {
  if (auto argument = dyn_cast<BlockArgument>(value)) {
    Block* owner = argument.getOwner();
    return owner->getParent();
  }
  Operation* producer = value.getDefiningOp();
  return producer ? producer->getParentRegion() : nullptr;
}
|
|
|
|
/// Returns true when `value` is defined in `region` itself or in any region
/// nested inside it. Returns false when the defining region cannot be
/// determined.
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  // Take the address of `region` to compare it with the defining region.
  return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
|
|
|
|
/// Returns true when `value` is produced by an operation carrying the
/// ConstantLike trait. Values with no defining operation yield false.
static bool isConstantExternalValue(Value value) {
  if (Operation* producer = value.getDefiningOp())
    return producer->hasTrait<OpTrait::ConstantLike>();
  return false;
}
|
|
|
|
/// Walks every operation in `region` and checks that each operand is either
/// defined inside `region` or produced by a ConstantLike operation.
///
/// One error is emitted on `ownerOp` per offending operand, each with a note
/// at the using operation's location. The walk is not interrupted, so all
/// violations are reported. `kind` is prepended to the error text.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  unsigned violationCount = 0;
  region.walk([&](Operation* nested) {
    for (OpOperand& use : nested->getOpOperands()) {
      Value used = use.get();
      const bool allowed = isDefinedInsideRegion(used, region) || isConstantExternalValue(used);
      if (allowed)
        continue;

      ++violationCount;
      // The diagnostic is reported when `report` goes out of scope.
      InFlightDiagnostic report = ownerOp->emitOpError()
                               << kind << " body may only directly reference external constants";
      report.attachNote(nested->getLoc())
          << "non-constant external operand #" << use.getOperandNumber() << " is used by " << nested->getName();
    }
  });
  return success(violationCount == 0);
}
|
|
|
|
/// Verifies the body block of a spat.compute_batch.
///
/// A resultless batch must end in an operand-free spat.yield; a batch with
/// results must end in spat.in_parallel. The body must have a lane block
/// argument, and every tensor.extract_slice directly in `block` must satisfy
/// verifyStaticUnitStrideExtractSliceOp. Errors for the terminator and lane
/// argument are emitted on `batchOp`.
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
  // Block::getTerminator() may only be called on a block that ends in a
  // terminator. This verifier can run on an empty or not-yet-terminated body,
  // so fall back to nullptr and let the checks below report the problem.
  Operation* terminator = block.mightHaveTerminator() ? block.getTerminator() : nullptr;

  if (batchOp.getNumResults() == 0) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(terminator);
    if (!yieldOp)
      return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
    if (yieldOp.getNumOperands() != 0)
      return batchOp.emitError("resultless compute_batch body yield must be empty");
  }
  else if (!isa_and_nonnull<SpatInParallelOp>(terminator)) {
    return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
  }

  auto laneArg = batchOp.getLaneArgument();
  if (!laneArg)
    return batchOp.emitError("compute_batch body must have a lane block argument");

  // Only operations directly in the body block are inspected here.
  for (auto& bodyOp : block) {
    if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
      if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
        return failure();
  }
  return success();
}
|
|
|
|
} // namespace
|
|
|
|
/// Verifies the operand and result shapes of a matrix-vector multiply,
/// dispatching on the rank of the weight (2 or 4). Any other weight rank is
/// an error.
LogicalResult SpatMVMOp::verify() {
  auto weightShape = getWeightShapeForWeightedOp(getWeight());
  if (failed(weightShape))
    return emitError("weight must be a shaped value");

  // The helpers take the shapes by reference, so keep them in named locals.
  auto matrixShape = *weightShape;
  auto vectorShape = getInput().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  switch (matrixShape.size()) {
  case 2:
    return mvmOpVerifySize2(this, matrixShape, vectorShape, outputShape);
  case 4:
    return mvmOpVerifySize4(this, matrixShape, vectorShape, outputShape);
  default:
    return emitError("matrix rank must be 2 or 4");
  }
}
|
|
|
|
/// Verifies the shapes of a vector-matrix multiply: a (1, N) vector times an
/// (N, M) matrix producing a (1, M) output, all of rank 2.
LogicalResult SpatVMMOp::verify() {
  auto weightShape = getWeightShapeForWeightedOp(getWeight());
  if (failed(weightShape))
    return emitError("weight must be a shaped value");

  auto matrixShape = *weightShape;
  auto vectorShape = getInput().getType().getShape();
  auto outputShape = getOutput().getType().getShape();

  const bool allRankTwo = matrixShape.size() == 2 && vectorShape.size() == 2 && outputShape.size() == 2;
  if (!allRankTwo)
    return emitError("matrix, vector and output must have rank 2");

  const int64_t rows = matrixShape.front();
  const int64_t cols = matrixShape.back();
  if (!(rows > 0 && cols > 0))
    return emitError("matrix shape must be (N, M) with N > 0 and M > 0");

  // The vector is a single row whose length equals the matrix row count.
  const bool vectorOk = vectorShape.front() == 1 && vectorShape.back() == rows;
  if (!vectorOk)
    return emitError("vector shape must be (1, N)");

  // The output is a single row whose length equals the matrix column count.
  const bool outputOk = outputShape.front() == 1 && outputShape.back() == cols;
  if (!outputOk)
    return emitError("output shape must be (1, M)");

  return success();
}
|
|
|
|
/// Requires at least two operands, then delegates to MLIR's
/// SameOperandsAndResultType check.
LogicalResult SpatVAddOp::verify() {
  const LogicalResult operandCountCheck = OpTrait::impl::verifyAtLeastNOperands(*this, 2);
  if (succeeded(operandCountCheck))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
|
|
|
/// Requires at least two operands, then delegates to MLIR's
/// SameOperandsAndResultType check.
LogicalResult SpatVMaxOp::verify() {
  const LogicalResult operandCountCheck = OpTrait::impl::verifyAtLeastNOperands(*this, 2);
  if (succeeded(operandCountCheck))
    return OpTrait::impl::verifySameOperandsAndResultType(*this);
  return failure();
}
|
|
|
|
/// Verifies that a rank-2 input is split into one rank-2, single-row result
/// per input row, with matching element type and column count.
///
/// Row-count and column-count comparisons are skipped when the corresponding
/// input extent is negative (presumably the dynamic-size sentinel).
LogicalResult SpatExtractRowsOp::verify() {
  auto inputType = dyn_cast<ShapedType>(getInput().getType());
  const bool inputIsMatrix = inputType && inputType.hasRank() && inputType.getRank() == 2;
  if (!inputIsMatrix)
    return emitError("input must be a rank-2 shaped type");

  auto inputShape = inputType.getShape();
  const int64_t rowCount = inputShape[0];
  const int64_t colCount = inputShape[1];
  Type inputElementType = inputType.getElementType();

  if (rowCount >= 0 && rowCount != static_cast<int64_t>(getNumResults()))
    return emitError("number of outputs must match the number of input rows");

  for (Type resultType : getResultTypes()) {
    auto rowType = dyn_cast<ShapedType>(resultType);
    const bool rowIsMatrix = rowType && rowType.hasRank() && rowType.getRank() == 2;
    if (!rowIsMatrix)
      return emitError("outputs must all be rank-2 shaped types");
    if (rowType.getElementType() != inputElementType)
      return emitError("output element types must match input element type");

    auto rowShape = rowType.getShape();
    if (rowShape[0] != 1)
      return emitError("each output must have exactly one row");
    if (colCount >= 0 && rowShape[1] != colCount)
      return emitError("output column count must match input column count");
  }

  return success();
}
|
|
|
|
/// Verifies a concatenation along `axis`.
///
/// Requires at least one input, a ranked output, an axis within the output
/// rank, and inputs that are ranked with the output's rank and element type.
/// Static extents on the non-concatenated dimensions must match the output.
/// If no input has a dynamic extent on the axis and the output extent there is
/// static, the input extents must sum to the output extent.
LogicalResult SpatConcatOp::verify() {
  if (getInputs().empty())
    return emitError("requires at least one input");

  auto outputType = dyn_cast<ShapedType>(getOutput().getType());
  if (!(outputType && outputType.hasRank()))
    return emitError("output must be a ranked shaped type");

  const int64_t rank = outputType.getRank();
  const int64_t axis = getAxis();
  if (!(0 <= axis && axis < rank))
    return emitError("axis must be within the output rank");

  Type outputElementType = outputType.getElementType();
  int64_t staticAxisTotal = 0;
  bool sawDynamicAxisExtent = false;

  for (Value input : getInputs()) {
    auto inputType = dyn_cast<ShapedType>(input.getType());
    if (!(inputType && inputType.hasRank()))
      return emitError("inputs must be ranked shaped types");
    if (inputType.getRank() != rank)
      return emitError("all inputs must have the same rank as the output");
    if (inputType.getElementType() != outputElementType)
      return emitError("all inputs must have the same element type as the output");

    // Dimensions other than the axis must agree wherever both are static.
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dim == axis)
        continue;
      const int64_t inputExtent = inputType.getDimSize(dim);
      const int64_t outputExtent = outputType.getDimSize(dim);
      const bool bothStatic = !ShapedType::isDynamic(inputExtent) && !ShapedType::isDynamic(outputExtent);
      if (bothStatic && inputExtent != outputExtent)
        return emitError("non-concatenated dimensions must match the output shape");
    }

    const int64_t axisExtent = inputType.getDimSize(axis);
    if (ShapedType::isDynamic(axisExtent))
      sawDynamicAxisExtent = true;
    else
      staticAxisTotal += axisExtent;
  }

  // The sum is only meaningful when every contribution was static.
  const int64_t outputAxisExtent = outputType.getDimSize(axis);
  const bool sumIsComparable = !sawDynamicAxisExtent && !ShapedType::isDynamic(outputAxisExtent);
  if (sumIsComparable && staticAxisTotal != outputAxisExtent)
    return emitError("output concatenated dimension must equal the sum of input sizes");

  return success();
}
|
|
|
|
/// Checks that no result of a SpatCompute/SpatComputeBatch operation is used
/// by an operation nested inside a SpatCompute or SpatComputeBatch.
///
/// Emits an error on `op` and returns failure when `op` is not one of those
/// two operations, or when such a use is found.
LogicalResult verifyComputeResultsUses(Operation* op) {
  if (!isa<SpatCompute, SpatComputeBatch>(op))
    return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");

  for (Value result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      const bool nestedInCompute = user->getParentOfType<SpatCompute>() || user->getParentOfType<SpatComputeBatch>();
      if (nestedInCompute)
        return op->emitError("ComputeResult used directly inside another Compute");
    }
  }
  return success();
}
|
|
|
|
/// Verifies a spat.compute operation.
///
/// Checks, in order:
///  - the body block has one argument per weight and per input, with types
///    identical to the corresponding operands;
///  - when the body may have a terminator, it is a spat.yield whose operand
///    types are identical to the op's result types;
///  - every input block argument has at least one use;
///  - the body references only internal values or external constants;
///  - no result is used inside another compute/compute_batch.
LogicalResult SpatCompute::verify() {
  auto& block = getBody().front();
  unsigned expectedArgCount = getWeights().size() + getInputs().size();
  if (block.getNumArguments() != expectedArgCount)
    return emitError("compute body must have weight and input block arguments");

  for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
    auto blockArg = getWeightArgument(weightIndex);
    if (!blockArg || blockArg->getType() != weight.getType())
      return emitError("compute weight block argument types must match weight operand types exactly");
  }
  for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
    auto blockArg = getInputArgument(inputIndex);
    if (!blockArg || blockArg->getType() != input.getType())
      return emitError("compute input block argument types must match input operand types exactly");
  }

  // getTerminator() may only be called on a block that ends in a terminator.
  if (block.mightHaveTerminator()) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
    if (!yieldOp)
      return emitError("ComputeOp must have a single yield operation");

    auto resultTypes = getResultTypes();
    auto yieldTypes = yieldOp->getOperandTypes();
    if (resultTypes.size() != yieldTypes.size())
      return emitError("ComputeOp must have same number of results as yieldOp operands");

    // Exact type equality already covers shape and tensor encoding. The shape
    // compatibility and encoding checks that used to follow could never fire
    // once the types compared equal, so they were removed.
    for (auto typePair : llvm::zip(resultTypes, yieldTypes)) {
      Type resultType = std::get<0>(typePair);
      Type yieldType = std::get<1>(typePair);
      if (resultType != yieldType)
        return emitError("ComputeOp output must be of the same type as yieldOp operand");
    }
  }

  // Only input arguments are required to be used; weights are not checked.
  for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
    if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
      return emitError("ComputeOp block argument is not used");

  if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
    return failure();
  if (failed(verifyComputeResultsUses(this->getOperation())))
    return failure();
  return success();
}
|
|
|
|
/// Verifies channel metadata lengths and payload divisibility for the sent
/// tensor; see verifyTensorChannelSizes.
LogicalResult SpatChannelSendTensorOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getInput().getType();
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyTensorChannelSizes(self, payloadType, channelCount, sourceCount, targetCount, "channel_send_tensor");
}
|
|
|
|
/// Verifies channel metadata lengths and payload divisibility for the received
/// tensor; see verifyTensorChannelSizes.
LogicalResult SpatChannelReceiveTensorOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getOutput().getType();
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyTensorChannelSizes(self, payloadType, channelCount, sourceCount, targetCount, "channel_receive_tensor");
}
|
|
|
|
/// Verifies that the channel metadata arrays agree in length and match the
/// parent compute_batch lane count; see verifyBatchChannelSizes.
LogicalResult SpatChannelSendBatchOp::verify() {
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyBatchChannelSizes(getOperation(), channelCount, sourceCount, targetCount);
}
|
|
|
|
/// Verifies channel metadata lengths against the parent lane count and the
/// payload divisibility for the sent tensor; see verifyTensorBatchChannelSizes.
LogicalResult SpatChannelSendTensorBatchOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getInput().getType();
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyTensorBatchChannelSizes(
      self, payloadType, channelCount, sourceCount, targetCount, "channel_send_tensor_batch");
}
|
|
|
|
/// Verifies that the channel metadata arrays agree in length and match the
/// parent compute_batch lane count; see verifyBatchChannelSizes.
LogicalResult SpatChannelReceiveBatchOp::verify() {
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyBatchChannelSizes(getOperation(), channelCount, sourceCount, targetCount);
}
|
|
|
|
/// Verifies channel metadata lengths against the parent lane count and the
/// payload divisibility for the received tensor; see
/// verifyTensorBatchChannelSizes.
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
  Operation* self = getOperation();
  Type payloadType = getOutput().getType();
  const size_t channelCount = getChannelIds().size();
  const size_t sourceCount = getSourceCoreIds().size();
  const size_t targetCount = getTargetCoreIds().size();
  return verifyTensorBatchChannelSizes(
      self, payloadType, channelCount, sourceCount, targetCount, "channel_receive_tensor_batch");
}
|
|
|
|
/// Verifies a spat.compute_batch operation.
///
/// Checks, in order:
///  - laneCount is positive;
///  - the optional coreIds attribute is a dense i32 array with one
///    non-negative, unique entry per lane;
///  - the body block has exactly 1 + #weights + #inputs + #results arguments,
///    the lane argument has index type, and weight/input/output arguments
///    have the types of the matching operands/results;
///  - no result is used inside another compute/compute_batch;
///  - the body references only internal values or external constants;
///  - the body itself passes verifyBatchBody.
LogicalResult SpatComputeBatch::verify() {
  const int32_t lanes = getLaneCount();
  if (lanes <= 0)
    return emitError("laneCount must be positive");

  if (auto rawCoreIds = (*this)->getAttr(kCoreIdsAttrName)) {
    auto coreIds = dyn_cast<DenseI32ArrayAttr>(rawCoreIds);
    if (!coreIds)
      return emitError("compute_batch coreIds attribute must be a dense i32 array");
    if (coreIds.size() != static_cast<int64_t>(lanes))
      return emitError("compute_batch coreIds array length must match laneCount");

    // Two separate passes: a negative id anywhere is reported before any
    // duplicate.
    for (int32_t id : coreIds.asArrayRef())
      if (id < 0)
        return emitError("compute_batch coreIds values must be non-negative");
    DenseSet<int32_t> visitedIds;
    for (int32_t id : coreIds.asArrayRef())
      if (!visitedIds.insert(id).second)
        return emitError("compute_batch coreIds values must be unique");
  }

  Block& body = getBody().front();
  const unsigned actualArgCount = body.getNumArguments();
  if (actualArgCount == 0)
    return emitError("compute_batch body must have exactly one lane block argument");
  const unsigned requiredArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
  if (actualArgCount != requiredArgCount)
    return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");

  auto laneArg = getLaneArgument();
  if (!laneArg || !laneArg->getType().isIndex())
    return emitError("compute_batch first block argument must have index type");

  size_t weightPosition = 0;
  for (Value weight : getWeights()) {
    auto weightArg = getWeightArgument(weightPosition++);
    if (!weightArg || weightArg->getType() != weight.getType())
      return emitError("compute_batch weight block argument types must match weight operand types exactly");
  }

  size_t inputPosition = 0;
  for (Value input : getInputs()) {
    auto inputArg = getInputArgument(inputPosition++);
    if (!inputArg || inputArg->getType() != input.getType())
      return emitError("compute_batch input block argument types must match input operand types exactly");
  }

  size_t resultPosition = 0;
  for (Type resultType : getResultTypes()) {
    auto outputArg = getOutputArgument(resultPosition++);
    if (!outputArg || outputArg->getType() != resultType)
      return emitError("compute_batch output block argument types must match result types exactly");
  }

  Operation* self = this->getOperation();
  if (failed(verifyComputeResultsUses(self)))
    return failure();
  if (failed(verifyOnlyConstantExternalValues(self, getBody(), "spat.compute_batch")))
    return failure();
  return verifyBatchBody(*this, body);
}
|
|
|
|
/// Verifies a spat.in_parallel terminator.
///
/// It must sit inside a spat.compute_batch that has results and a lane block
/// argument. Its region may contain only tensor.parallel_insert_slice ops;
/// each must pass verifyStaticUnitStrideParallelInsertSliceOp and may only
/// insert into output block arguments of the enclosing batch.
LogicalResult SpatInParallelOp::verify() {
  auto enclosingBatch = getOperation()->getParentOfType<SpatComputeBatch>();
  if (!enclosingBatch)
    return emitOpError("expected spat.compute_batch parent");
  if (enclosingBatch.getNumResults() == 0)
    return emitOpError("requires a resultful spat.compute_batch parent");

  auto laneArg = enclosingBatch.getLaneArgument();
  if (!laneArg)
    return emitOpError("expected compute_batch lane block argument");

  for (Operation& nested : getRegion().front()) {
    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&nested);
    if (!insertOp)
      return emitOpError("expected only tensor.parallel_insert_slice ops");

    if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertOp, *laneArg, "tensor.parallel_insert_slice")))
      return failure();

    // Every destination written by the insert must be a batch output argument.
    MutableOperandRange updatedDests = insertOp.getUpdatedDestinations();
    for (OpOperand& dest : updatedDests)
      if (!isBatchOutputArgument(enclosingBatch, dest.get()))
        return nested.emitOpError("may only insert into a compute_batch output block argument");
  }

  return success();
}
|
|
|
|
} // namespace spatial
|
|
} // namespace onnx_mlir
|