remove useless MaterializeHostConstantsPass.cpp and fix lowering before instead
Validate Operations / validate-operations (push) Has been cancelled

avoid spammy pim codegen diagnostics
This commit is contained in:
NiccoloN
2026-06-05 10:06:28 +02:00
parent 27410207c4
commit 1e9e61f5a9
20 changed files with 458 additions and 256 deletions
+62 -18
View File
@@ -47,6 +47,16 @@ CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
return mlir::failure();
return strides;
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
@@ -162,6 +172,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
mlir::Value current = weight;
while (true) {
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
current = directAlias;
continue;
}
if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
@@ -181,8 +196,6 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
llvm::SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getMixedOffsets().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0);
@@ -202,29 +215,47 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure();
}
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
view.strides = std::move(*resultStrides);
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
view.strides = std::move(*resultStrides);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
return mlir::failure();
}
auto resolvedOffset = offsetExpr.evaluate(knowledge);
@@ -234,18 +265,26 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return view;
}
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
viewOps.push_back(defOp);
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
current = subview.getSource();
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
current = collapse.getSrc();
else
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
current = subview.getSource();
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = collapse.getSrc();
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = expand.getSrc();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
viewOps.push_back(defOp);
current = castOp.getSource();
continue;
}
@@ -253,6 +292,11 @@ resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const Static
return mlir::failure();
}
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
current = loopAlias;
continue;
}
auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex)
return mlir::failure();
+2
View File
@@ -28,6 +28,8 @@ struct CappedDiagnosticReporter {
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
void noteFailures(int64_t count) { numFailures += count; }
bool hasFailure() const { return numFailures != 0; }
private: