fix much stuff
This commit is contained in:
@@ -107,8 +107,11 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
return false;
|
||||
|
||||
unsigned argNumber = blockArg.getArgNumber();
|
||||
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
|
||||
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
|
||||
auto firstOutputArg = batchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return false;
|
||||
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
|
||||
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
@@ -293,10 +296,12 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||
}
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return batchOp.emitError("compute_batch body must have a lane block argument");
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -457,12 +462,16 @@ LogicalResult SpatCompute::verify() {
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
@@ -497,7 +506,7 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (getInputArgument(inputIndex).use_empty())
|
||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
@@ -574,23 +583,28 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() == 0)
|
||||
return emitError("compute_batch body must have exactly one lane block argument");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||
auto laneArg = getLaneArgument();
|
||||
if (!laneArg || !laneArg->getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
BlockArgument blockArg = getOutputArgument(resultIndex);
|
||||
if (blockArg.getType() != resultType)
|
||||
auto blockArg = getOutputArgument(resultIndex);
|
||||
if (!blockArg || blockArg->getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
@@ -608,13 +622,15 @@ LogicalResult SpatInParallelOp::verify() {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return emitOpError("expected compute_batch lane block argument");
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSliceOp)
|
||||
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
|
||||
return failure();
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
|
||||
Reference in New Issue
Block a user