This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||||
|
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||||
|
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||||
|
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||||
|
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||||
|
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||||
|
});
|
||||||
|
})) {
|
||||||
|
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatCompute::verify() {
|
||||||
auto& block = getBody().front();
|
auto& block = getBody().front();
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
|
|||||||
for (auto arg : block.getArguments())
|
for (auto arg : block.getArguments())
|
||||||
if (arg.use_empty())
|
if (arg.use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return emitError("ComputeOp block argument is not used");
|
||||||
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("body block argument type must match input type");
|
return emitError("body block argument type must match input type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
|
return failure();
|
||||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
@@ -65,6 +65,7 @@ std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
|||||||
if (!op)
|
if (!op)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
//TODO Extract Slice is not the only global non compute operation. There are other legal op
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
value = extract.getSource();
|
value = extract.getSource();
|
||||||
op = value.getDefiningOp();
|
op = value.getDefiningOp();
|
||||||
|
|||||||
Reference in New Issue
Block a user