add verification of static weights in spatial
Validate Operations / validate-operations (push) Waiting to run
Validate Operations / validate-operations (push) Waiting to run
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -116,6 +117,15 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
|||||||
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) {
|
||||||
|
for (Value weight : computeOp.getWeights()) {
|
||||||
|
if (!isHostFoldableValue(weight))
|
||||||
|
return computeOp.emitOpError() << kind << " weights must be statically computed from constants";
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
static bool isConstantIndexLike(Value value) {
|
static bool isConstantIndexLike(Value value) {
|
||||||
APInt constantValue;
|
APInt constantValue;
|
||||||
return matchPattern(value, m_ConstantInt(&constantValue));
|
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||||
@@ -545,6 +555,8 @@ LogicalResult SpatCompute::verify() {
|
|||||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return emitError("ComputeOp block argument is not used");
|
||||||
|
if (failed(verifyStaticWeights(*this, "compute")))
|
||||||
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||||
return failure();
|
return failure();
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
@@ -647,6 +659,8 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
|
|
||||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
return failure();
|
return failure();
|
||||||
|
if (failed(verifyStaticWeights(*this, "compute_batch")))
|
||||||
|
return failure();
|
||||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||||
return failure();
|
return failure();
|
||||||
return verifyBatchBody(*this, block);
|
return verifyBatchBody(*this, block);
|
||||||
|
|||||||
Reference in New Issue
Block a user