This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
@@ -82,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
/// Returns the shape of `weight` when its type implements ShapedType.
///
/// \param weight  Value whose type is inspected.
/// \returns the shape reported by the shaped type, or failure() when the type
///          is not a ShapedType (the caller is expected to emit the diagnostic).
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
  if (auto weightType = dyn_cast<ShapedType>(weight.getType()))
    return weightType.getShape();
  return failure();
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
@@ -105,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
/// Tells whether `value` is one of the output block arguments of `batchOp`.
///
/// \param batchOp  Batch op whose first body block owns the candidate arguments.
/// \param value    Value to classify.
/// \returns false when the batch op has no results, when `value` is not a
///          block argument, or when it belongs to a different block; otherwise
///          whether its argument number lies in the output-argument window.
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
  const unsigned resultCount = batchOp.getNumResults();
  if (resultCount == 0)
    return false;

  auto candidate = dyn_cast<BlockArgument>(value);
  if (!candidate)
    return false;
  if (candidate.getOwner() != &batchOp.getBody().front())
    return false;

  // The check treats the outputs as a run of `resultCount` consecutive
  // argument numbers starting at the number of output argument #0.
  const unsigned candidateNumber = candidate.getArgNumber();
  const unsigned windowBegin = batchOp.getOutputArgument(0).getArgNumber();
  return windowBegin <= candidateNumber && candidateNumber < windowBegin + resultCount;
}
|
||||
|
||||
/// Returns true when `value` matches the m_ConstantInt pattern. The matched
/// integer is bound to a scratch variable and then discarded.
static bool isConstantIndexLike(Value value) {
  APInt ignoredConstant;
  const bool matched = matchPattern(value, m_ConstantInt(&ignoredConstant));
  return matched;
}
|
||||
|
||||
/// Recognises the offset expressions accepted for the lane-dependent slot:
///   * the lane argument itself,
///   * an integer constant,
///   * an arith.addi whose operands are the lane argument and an integer
///     constant, in either order.
///
/// \param value    Offset value to classify.
/// \param laneArg  Lane block argument to compare against.
/// \returns true for one of the forms above, false otherwise.
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
  if (value == laneArg)
    return true;
  if (isConstantIndexLike(value))
    return true;

  if (auto sum = value.getDefiningOp<arith::AddIOp>()) {
    Value left = sum.getLhs();
    Value right = sum.getRhs();
    if (left == laneArg && isConstantIndexLike(right))
      return true;
    return right == laneArg && isConstantIndexLike(left);
  }
  return false;
}
|
||||
|
||||
/// Verifies the restrictions placed on a tensor.extract_slice op.
///
/// Checks, in order: source and result are statically shaped ranked tensors,
/// all strides are unit, every static size is non-dynamic, and every entry of
/// getOffsets() has a supported form (entry 0 may be lane-dependent, the
/// others must be integer constants).
///
/// \param sliceOp  Slice op under verification; diagnostics are emitted on it.
/// \param laneArg  Lane block argument allowed in the first offset.
/// \param kind     Prefix used in the diagnostic text.
/// \returns success() when all checks pass, otherwise the emitted op error.
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
  auto isStaticRankedTensor = [](Type type) {
    auto ranked = dyn_cast<RankedTensorType>(type);
    return ranked && ranked.hasStaticShape();
  };
  if (!isStaticRankedTensor(sliceOp.getSource().getType()) || !isStaticRankedTensor(sliceOp.getResult().getType()))
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";

  if (!sliceOp.hasUnitStride())
    return sliceOp.emitOpError() << kind << " requires unit strides";

  if (llvm::any_of(sliceOp.getStaticSizes(), [](int64_t size) { return ShapedType::isDynamic(size); }))
    return sliceOp.emitOpError() << kind << " requires static slice sizes";

  // NOTE(review): the position below indexes the range returned by
  // getOffsets(). If that range holds only the dynamic offset operands,
  // position 0 is the first dynamic offset rather than dimension 0 -- confirm
  // this is the intended meaning.
  for (auto indexedOffset : llvm::enumerate(sliceOp.getOffsets())) {
    Value offset = indexedOffset.value();
    const bool accepted =
        indexedOffset.index() == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
    if (accepted)
      continue;
    return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
  }

  return success();
}
|
||||
|
||||
/// Verifies the restrictions placed on a tensor.parallel_insert_slice op.
///
/// Checks, in order: source and destination types have static shapes, all
/// strides are unit, every static size is non-dynamic, and every entry of
/// getOffsets() has a supported form (entry 0 may be lane-dependent, the
/// others must be integer constants).
///
/// \param sliceOp  Insert op under verification; diagnostics are emitted on it.
/// \param laneArg  Lane block argument allowed in the first offset.
/// \param kind     Prefix used in the diagnostic text.
/// \returns success() when all checks pass, otherwise the emitted op error.
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
                                                                 BlockArgument laneArg,
                                                                 StringRef kind) {
  const bool shapesAreStatic = sliceOp.getSourceType().hasStaticShape() && sliceOp.getDestType().hasStaticShape();
  if (!shapesAreStatic)
    return sliceOp.emitOpError() << kind << " requires static ranked tensor types";

  if (!sliceOp.hasUnitStride())
    return sliceOp.emitOpError() << kind << " requires unit strides";

  if (llvm::any_of(sliceOp.getStaticSizes(), [](int64_t size) { return ShapedType::isDynamic(size); }))
    return sliceOp.emitOpError() << kind << " requires static slice sizes";

  // NOTE(review): the position below indexes the range returned by
  // getOffsets(). If that range holds only the dynamic offset operands,
  // position 0 is the first dynamic offset rather than dimension 0 -- confirm
  // this is the intended meaning.
  for (auto indexedOffset : llvm::enumerate(sliceOp.getOffsets())) {
    Value offset = indexedOffset.value();
    const bool accepted =
        indexedOffset.index() == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
    if (accepted)
      continue;
    return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
  }

  return success();
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.empty())
|
||||
if (channelCount == 0)
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -125,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Checks the channel metadata lengths of a batch channel op.
///
/// \param op               Operation the diagnostics are emitted on; also the
///                         op handed to getParentBatchLaneCount.
/// \param channelCount     Number of channel ids.
/// \param sourceCoreCount  Number of source core ids.
/// \param targetCoreCount  Number of target core ids.
/// \returns success() when the three counts agree and equal the lane count
///          obtained from getParentBatchLaneCount(op); otherwise an emitted
///          error.
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
  if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
    return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");

  auto laneCount = getParentBatchLaneCount(op);
  if (failed(laneCount))
    return op->emitError("must be nested inside spat.compute_batch");
  // Exactly one metadata entry is required per lane.
  if (channelCount != static_cast<size_t>(*laneCount))
    return op->emitError("channel metadata length must match parent laneCount");

  return success();
}
|
||||
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult verifyTensorBatchChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -169,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
@@ -177,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return op->emitError("body must terminate with spat.yield");
|
||||
if (outputTypes.empty()) {
|
||||
/// Finds the region in which `value` is defined.
///
/// \returns for a block argument, the parent region of its owning block; for
///          an op result, the parent region of the defining op; nullptr when
///          neither applies.
static Region* getParentRegion(Value value) {
  auto asBlockArg = dyn_cast<BlockArgument>(value);
  if (asBlockArg)
    return asBlockArg.getOwner()->getParent();

  Operation* producer = value.getDefiningOp();
  return producer ? producer->getParentRegion() : nullptr;
}
|
||||
|
||||
/// Tells whether `value` is defined in `region` or in a region nested in it.
///
/// \param value   Value whose defining region is looked up via getParentRegion.
/// \param region  Region to test against.
/// \returns false when no defining region is found; otherwise true when that
///          region is `region` itself or `region` is an ancestor of it.
static bool isDefinedInsideRegion(Value value, Region& region) {
  Region* parentRegion = getParentRegion(value);
  // `&region` is the address of the reference parameter.
  return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
|
||||
|
||||
/// Returns true when `value` is produced by an operation carrying the
/// ConstantLike trait; values without a defining op yield false.
static bool isConstantExternalValue(Value value) {
  if (Operation* producer = value.getDefiningOp())
    return producer->hasTrait<OpTrait::ConstantLike>();
  return false;
}
|
||||
|
||||
/// Ensures every operand used inside `region` is either defined inside the
/// region or produced by a ConstantLike operation.
///
/// The walk does not stop at the first offender: one error (with a note
/// pointing at the using op) is emitted on `ownerOp` per offending operand.
///
/// \param ownerOp  Operation the errors are emitted on.
/// \param region   Region whose nested operations are walked.
/// \param kind     Prefix used in the diagnostic text.
/// \returns success() when no offending operand was found, failure() otherwise.
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
  unsigned violationCount = 0;
  region.walk([&](Operation* nestedOp) {
    for (OpOperand& use : nestedOp->getOpOperands()) {
      Value used = use.get();
      const bool allowed = isDefinedInsideRegion(used, region) || isConstantExternalValue(used);
      if (allowed)
        continue;

      InFlightDiagnostic report = ownerOp->emitOpError()
                               << kind << " body may only directly reference external constants";
      report.attachNote(nestedOp->getLoc())
          << "non-constant external operand #" << use.getOperandNumber() << " is used by " << nestedOp->getName();
      ++violationCount;
    }
  });
  return success(violationCount == 0);
}
|
||||
|
||||
/// Verifies the body block of a compute_batch op.
///
/// * With no results, the block must terminate with a spat.yield that has no
///   operands.
/// * With results, the block must terminate with spat.in_parallel.
/// * Every tensor.extract_slice directly in the block must satisfy
///   verifyStaticUnitStrideExtractSliceOp with respect to the lane argument.
///
/// \param batchOp  Batch op being verified; errors are emitted on it.
/// \param block    Body block to inspect.
/// \returns success() when all checks pass, otherwise an emitted error or
///          failure() propagated from the slice verifier.
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
  if (batchOp.getNumResults() == 0) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
    if (!yieldOp)
      return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
    if (yieldOp.getNumOperands() != 0)
      return batchOp.emitError("resultless compute_batch body yield must be empty");
  }
  else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
    return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
  }

  BlockArgument laneArg = batchOp.getLaneArgument();
  for (auto& bodyOp : block) {
    if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
      if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
        return failure();
  }
  return success();
}
|
||||
@@ -206,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -221,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -347,13 +437,26 @@ LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||
return op->emitError("ComputeResult used directly inside another Compute");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
@@ -386,9 +489,11 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (getInputArgument(inputIndex).use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
@@ -397,44 +502,46 @@ LogicalResult SpatCompute::verify() {
|
||||
/// Verifies channel_send_tensor by forwarding the input type and the lengths
/// of the channel-id, source-core-id and target-core-id lists to
/// verifyTensorChannelSizes.
LogicalResult SpatChannelSendTensorOp::verify() {
  return verifyTensorChannelSizes(getOperation(),
                                  getInput().getType(),
                                  getChannelIds().size(),
                                  getSourceCoreIds().size(),
                                  getTargetCoreIds().size(),
                                  "channel_send_tensor");
}
|
||||
|
||||
/// Verifies channel_receive_tensor by forwarding the output type and the
/// lengths of the channel-id, source-core-id and target-core-id lists to
/// verifyTensorChannelSizes.
LogicalResult SpatChannelReceiveTensorOp::verify() {
  return verifyTensorChannelSizes(getOperation(),
                                  getOutput().getType(),
                                  getChannelIds().size(),
                                  getSourceCoreIds().size(),
                                  getTargetCoreIds().size(),
                                  "channel_receive_tensor");
}
|
||||
|
||||
/// Verifies the batch send op by forwarding the lengths of the channel-id,
/// source-core-id and target-core-id lists to verifyBatchChannelSizes.
LogicalResult SpatChannelSendBatchOp::verify() {
  return verifyBatchChannelSizes(
      getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
|
||||
|
||||
/// Verifies channel_send_tensor_batch by forwarding the input type and the
/// lengths of the channel-id, source-core-id and target-core-id lists to
/// verifyTensorBatchChannelSizes.
LogicalResult SpatChannelSendTensorBatchOp::verify() {
  return verifyTensorBatchChannelSizes(getOperation(),
                                       getInput().getType(),
                                       getChannelIds().size(),
                                       getSourceCoreIds().size(),
                                       getTargetCoreIds().size(),
                                       "channel_send_tensor_batch");
}
|
||||
|
||||
/// Verifies the batch receive op by forwarding the lengths of the channel-id,
/// source-core-id and target-core-id lists to verifyBatchChannelSizes.
LogicalResult SpatChannelReceiveBatchOp::verify() {
  return verifyBatchChannelSizes(
      getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
|
||||
|
||||
/// Verifies channel_receive_tensor_batch by forwarding the output type and the
/// lengths of the channel-id, source-core-id and target-core-id lists to
/// verifyTensorBatchChannelSizes.
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
  return verifyTensorBatchChannelSizes(getOperation(),
                                       getOutput().getType(),
                                       getChannelIds().size(),
                                       getSourceCoreIds().size(),
                                       getTargetCoreIds().size(),
                                       "channel_receive_tensor_batch");
}
|
||||
|
||||
@@ -444,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
if (getWeights().size() % laneCountSz != 0)
|
||||
return emitError("number of weights must be a multiple of laneCount");
|
||||
|
||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
||||
return emitError("number of inputs must be either 0 or laneCount");
|
||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
||||
return emitError("number of outputs must be either 0 or laneCount");
|
||||
|
||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
||||
Type weightType = getWeights()[weightIndex].getType();
|
||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
||||
return emitError("corresponding weights across lanes must have the same type");
|
||||
}
|
||||
|
||||
if (!getInputs().empty()) {
|
||||
Type inputType = getInputs()[0].getType();
|
||||
for (Value in : getInputs().drop_front())
|
||||
if (in.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
}
|
||||
|
||||
if (!getOutputs().empty()) {
|
||||
Type outputType = getOutputs()[0].getType();
|
||||
for (Value out : getOutputs().drop_front())
|
||||
if (out.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
}
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
@@ -482,27 +560,64 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
DenseSet<int32_t> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
return emitError("compute_batch coreIds values must be distinct");
|
||||
return emitError("compute_batch coreIds values must be unique");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (getInputs().empty()) {
|
||||
if (block.getNumArguments() != 0)
|
||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
else {
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("compute_batch body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
||||
return emitError("body block argument type must match input type");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
BlockArgument blockArg = getOutputArgument(resultIndex);
|
||||
if (blockArg.getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
}
|
||||
|
||||
/// Verifies a spat.in_parallel op.
///
/// Requires an enclosing spat.compute_batch that has results. Every op in the
/// first block of this op's region must be a tensor.parallel_insert_slice that
/// passes verifyStaticUnitStrideParallelInsertSliceOp and whose updated
/// destinations are all output block arguments of the enclosing batch op.
///
/// \returns success() when all checks pass, otherwise an emitted op error or
///          failure() propagated from the slice verifier.
LogicalResult SpatInParallelOp::verify() {
  SpatComputeBatch enclosingBatch = getOperation()->getParentOfType<SpatComputeBatch>();
  if (!enclosingBatch)
    return emitOpError("expected spat.compute_batch parent");
  if (enclosingBatch.getNumResults() == 0)
    return emitOpError("requires a resultful spat.compute_batch parent");

  BlockArgument laneArg = enclosingBatch.getLaneArgument();
  for (Operation& nestedOp : getRegion().front().getOperations()) {
    auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
    if (!insertOp)
      return emitOpError("expected only tensor.parallel_insert_slice ops");

    if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertOp, laneArg, "tensor.parallel_insert_slice")))
      return failure();

    // Each destination written by the insert must be a batch output argument.
    MutableOperandRange updatedDests = insertOp.getUpdatedDestinations();
    for (OpOperand& dest : updatedDests) {
      if (isBatchOutputArgument(enclosingBatch, dest.get()))
        continue;
      return nestedOp.emitOpError("may only insert into a compute_batch output block argument");
    }
  }

  return success();
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
Reference in New Issue
Block a user