From 4855a2e105c294512b8ab8352e709176b51ce17b Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sun, 24 May 2026 12:00:42 +0200 Subject: [PATCH] add verification of static weights in spatial --- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 10c7591..2fbf0e2 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -15,6 +15,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -116,6 +117,15 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults(); } +template +static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) { + for (Value weight : computeOp.getWeights()) { + if (!isHostFoldableValue(weight)) + return computeOp.emitOpError() << kind << " weights must be statically computed from constants"; + } + return success(); +} + static bool isConstantIndexLike(Value value) { APInt constantValue; return matchPattern(value, m_ConstantInt(&constantValue)); @@ -545,6 +555,8 @@ LogicalResult SpatCompute::verify() { for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex) if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty()) return emitError("ComputeOp block argument is not used"); + if (failed(verifyStaticWeights(*this, "compute"))) + return failure(); if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute"))) return failure(); if (failed(verifyComputeResultsUses(this->getOperation()))) @@ -647,6 +659,8 @@ LogicalResult SpatComputeBatch::verify() { if (failed(verifyComputeResultsUses(this->getOperation()))) return failure(); + if (failed(verifyStaticWeights(*this, "compute_batch"))) + return failure(); if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch"))) return failure(); return verifyBatchBody(*this, block);