add verification of static weights in spatial
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -116,6 +117,15 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
|
||||
for (Value weight : computeOp.getWeights()) {
|
||||
if (!isHostFoldableValue(weight))
|
||||
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
APInt constantValue;
|
||||
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||
@@ -545,6 +555,8 @@ LogicalResult SpatCompute::verify() {
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyStaticWeights(*this, "compute")))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
@@ -647,6 +659,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
if (failed(verifyStaticWeights(*this, "compute_batch")))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
|
||||
Reference in New Issue
Block a user