This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("body block argument type must match input type");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
}
|
||||
|
||||
|
||||
+1
@@ -65,6 +65,7 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
//TODO Extract Slice is not the only global non compute operation. There are other legal op
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
|
||||
Reference in New Issue
Block a user