No extract no more
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-25 18:19:43 +02:00
parent b79c333c6c
commit bdc4ca33f3
9 changed files with 160 additions and 81 deletions
@@ -7,6 +7,9 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -195,66 +198,90 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
return nullptr;
}
static bool isCompileTimeOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op)
return false;
return std::nullopt;
if (!visited.insert(op).second)
return true;
return {
{op, chainLength}
};
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
return {
{op, chainLength}
};
chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isCompileTimeOpImpl(extractOp, visited);
return hasConstantIndices(extractOp) ? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength) : std::nullopt;
if (!isStaticTensorResult(op))
return false;
return std::nullopt;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isCompileTimeOpImpl(transposeOp, visited);
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isCompileTimeOpImpl(collapseShapeOp,visited);
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isCompileTimeOpImpl(expandShapeOp, visited);
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isCompileTimeOpImpl(extractSliceOp, visited);
return hasStaticUnitStrides(extractSliceOp) ? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isCompileTimeOpImpl(splatOp, visited);
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isCompileTimeOpImpl(extractRowsOp, visited);
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)){
bool res = true;
for(auto operandValue : concatOp.getOperands()){
res &= isCompileTimeOpImpl(operandValue.getDefiningOp(), visited);
if(!res) break;
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
std::optional<CompileTimeSource> res = {};
for (auto operandValue : concatOp.getOperands()) {
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
if (!partialRes)
return std::nullopt;
if (!res) {
res = partialRes;
continue;
}
if(res->chainLength < partialRes->chainLength){
res = partialRes;
}
}
return res;
}
return false;
return std::nullopt;
}
} // namespace
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited);
}
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isCompileTimeOpImpl(definingOp, visited);
return getCompileTimeSourceImpl(definingOp, visited).has_value();
}
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isCompileTimeOpImpl(op, visited);
return getCompileTimeSourceImpl(op, visited).has_value();
}
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {