This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e95f24f866f Fix pytest collection failure for classes decorated with 
context managers (#55915)
e95f24f866f is described below

commit e95f24f866fb5d9c81b37fd8f2ce7247889656c1
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Sep 23 04:42:46 2025 +0100

    Fix pytest collection failure for classes decorated with context managers (#55915)
    
    Classes decorated with `@conf_vars` and other context managers were
    silently dropped during pytest collection, so none of their tests ran.
    This affected several test classes, including `TestWorkerStart` in the
    Celery provider tests.
    
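    Root cause: when a context manager that inherits from
    `contextlib.ContextDecorator` (this includes every context manager built
    with `@contextlib.contextmanager`) is used as a class decorator, it
    replaces the class with a wrapper function. pytest only collects real
    classes (type objects) as test classes, so a class wrapped this way is
    ignored during collection.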
    
    Simple reproduction (no Airflow needed):
    
    ```py
    import contextlib
    import inspect
    
    @contextlib.contextmanager
    def simple_cm():
        yield
    
    @simple_cm()
    class TestExample:
        def test_method(self):
            pass
    
    print(f'Is class? {inspect.isclass(TestExample)}')  # False - pytest won't collect
    ```
    
    Then run:
    
    ```shell
    pytest test_example.py --collect-only
    ```
    
    Airflow reproduction (before this fix, `TestWorkerStart` and the other
    decorated classes in this file are missing from the collected items;
    after it, they are listed):
    
    ```shell
    breeze run pytest providers/celery/tests/unit/celery/cli/test_celery_command.py --collect-only -v
    ```
    
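    Solution:
    1. Replaced the class-level `@conf_vars` decorators in the affected test
       files with pytest fixtures
    2. Kafka integration tests: define the test connections in autouse
       fixtures that register them via `create_connection_without_db`
    3. Celery CLI tests: moved the config override into a
       `conf_stale_bundle_cleanup_disabled` fixture and applied it to each
       test class with `@pytest.mark.usefixtures`
    4. Added a custom linter to prevent future occurrences and integrated it
       into the pre-commit hooks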
    
    Files changed:
    - Fixed 3 test files with problematic class decorators
    - Added custom linter with pre-commit integration
    
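    This ensures pytest collects the affected test classes again. The new
    check flags known offenders (`conf_vars`, `env_vars`, `contextmanager`)
    used as decorators on `Test*` classes, so similar regressions are caught
    before they land.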
---
 .pre-commit-config.yaml                            |   6 +
 .../apache/kafka/operators/test_consume.py         |  41 ++++---
 .../apache/kafka/operators/test_produce.py         |  33 ++---
 .../tests/unit/celery/cli/test_celery_command.py   |  14 ++-
 .../prek/check_contextmanager_class_decorators.py  | 133 +++++++++++++++++++++
 5 files changed, 191 insertions(+), 36 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5c74a9cad8e..91bfd6c3203 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1649,3 +1649,9 @@ repos:
         files: 
^airflow-core/src/airflow/serialization/schema\.json$|^airflow-core/src/airflow/serialization/serialized_objects\.py$
         pass_filenames: false
         require_serial: true
+      - id: check-contextmanager-class-decorators
+        name: Check for problematic context manager class decorators
+        entry: ./scripts/ci/prek/check_contextmanager_class_decorators.py
+        language: python
+        files: .*test.*\.py$
+        pass_filenames: true
diff --git 
a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
 
b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
index ddca5fc502e..aefcee1d753 100644
--- 
a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
+++ 
b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
@@ -25,10 +25,9 @@ import pytest
 from confluent_kafka import Producer
 
 # Import Operator
+from airflow.models.connection import Connection
 from airflow.providers.apache.kafka.operators.consume import 
ConsumeFromTopicOperator
 
-from tests_common.test_utils.config import conf_vars
-
 log = logging.getLogger(__name__)
 
 
@@ -49,23 +48,29 @@ def _basic_message_tester(message, test=None) -> Any:
     assert message.value().decode(encoding="utf-8") == test
 
 
[email protected](autouse=True)
+def kafka_consumer_connections(create_connection_without_db):
+    """Create Kafka consumer connections for testing purpose."""
+    connections = [
+        Connection(
+            conn_id="operator.consumer.test.integration.test_1",
+            
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+        Connection(
+            conn_id="operator.consumer.test.integration.test_2",
+            
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+        Connection(
+            conn_id="operator.consumer.test.integration.test_3",
+            
uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+    ]
+
+    for conn in connections:
+        create_connection_without_db(conn)
+
+
 @pytest.mark.integration("kafka")
-@conf_vars(
-    {
-        (
-            "connections",
-            "operator.consumer.test.integration.test_1",
-        ): 
"kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
-        (
-            "connections",
-            "operator.consumer.test.integration.test_2",
-        ): 
"kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
-        (
-            "connections",
-            "operator.consumer.test.integration.test_3",
-        ): 
"kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
-    }
-)
 class TestConsumeFromTopic:
     """
     test ConsumeFromTopicOperator
diff --git 
a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
 
b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
index ada93900ce6..d7a5167916d 100644
--- 
a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
+++ 
b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
@@ -22,10 +22,9 @@ import logging
 import pytest
 from confluent_kafka import Consumer
 
+from airflow.models.connection import Connection
 from airflow.providers.apache.kafka.operators.produce import 
ProduceToTopicOperator
 
-from tests_common.test_utils.config import conf_vars
-
 log = logging.getLogger(__name__)
 
 
@@ -34,19 +33,25 @@ def _producer_function():
         yield (json.dumps(i), json.dumps(i + 1))
 
 
[email protected](autouse=True)
+def kafka_connections(create_connection_without_db):
+    """Create Kafka producer connections for testing purpose."""
+    connections = [
+        Connection(
+            conn_id="kafka_default_test_1",
+            
uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
+        ),
+        Connection(
+            conn_id="kafka_default_test_2",
+            
uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
+        ),
+    ]
+
+    for conn in connections:
+        create_connection_without_db(conn)
+
+
 @pytest.mark.integration("kafka")
-@conf_vars(
-    {
-        (
-            "connections",
-            "kafka_default_test_1",
-        ): 
"kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
-        (
-            "connections",
-            "kafka_default_test_2",
-        ): 
"kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
-    }
-)
 class TestProduceToTopic:
     """
     test ProduceToTopicOperator
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py 
b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index 41f74c9c4a6..94e339804e3 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -37,8 +37,14 @@ from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 
[email protected](autouse=False)
+def conf_stale_bundle_cleanup_disabled():
+    with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "0"}):
+        yield
+
+
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
[email protected]("conf_stale_bundle_cleanup_disabled")
 class TestCeleryStopCommand:
     @classmethod
     def setup_class(cls):
@@ -120,7 +126,7 @@ class TestCeleryStopCommand:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
[email protected]("conf_stale_bundle_cleanup_disabled")
 class TestWorkerStart:
     @classmethod
     def setup_class(cls):
@@ -181,7 +187,7 @@ class TestWorkerStart:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
[email protected]("conf_stale_bundle_cleanup_disabled")
 class TestWorkerFailure:
     @classmethod
     def setup_class(cls):
@@ -201,7 +207,7 @@ class TestWorkerFailure:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
[email protected]("conf_stale_bundle_cleanup_disabled")
 class TestFlowerCommand:
     @classmethod
     def setup_class(cls):
diff --git a/scripts/ci/prek/check_contextmanager_class_decorators.py 
b/scripts/ci/prek/check_contextmanager_class_decorators.py
new file mode 100644
index 00000000000..149d00bc6c2
--- /dev/null
+++ b/scripts/ci/prek/check_contextmanager_class_decorators.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Check for problematic context manager decorators on test classes.
+
+Context managers (ContextDecorator, @contextlib.contextmanager) when used as 
class decorators
+transform the class into a callable wrapper, which prevents pytest from 
collecting the class.
+"""
+
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+
+class ContextManagerClassDecoratorChecker(ast.NodeVisitor):
+    """AST visitor to check for context manager decorators on test classes."""
+
+    def __init__(self, filename: str):
+        self.filename = filename
+        self.errors: list[str] = []
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> None:
+        """Check class definitions for problematic decorators."""
+        if not node.name.startswith("Test"):
+            self.generic_visit(node)
+            return
+
+        for decorator in node.decorator_list:
+            decorator_name = self._get_decorator_name(decorator)
+            if self._is_problematic_decorator(decorator_name):
+                self.errors.append(
+                    f"{self.filename}:{node.lineno}: Class '{node.name}' uses 
@{decorator_name} "
+                    f"decorator which prevents pytest collection. Use 
@pytest.mark.usefixtures instead."
+                )
+
+        self.generic_visit(node)
+
+    def _get_decorator_name(self, decorator: ast.expr) -> str:
+        """Extract decorator name from AST node."""
+        if isinstance(decorator, ast.Name):
+            return decorator.id
+        if isinstance(decorator, ast.Call):
+            if isinstance(decorator.func, ast.Name):
+                return decorator.func.id
+            if isinstance(decorator.func, ast.Attribute):
+                return f"{self._get_attr_chain(decorator.func)}"
+        elif isinstance(decorator, ast.Attribute):
+            return f"{self._get_attr_chain(decorator)}"
+        return "unknown"
+
+    def _get_attr_chain(self, node: ast.Attribute) -> str:
+        """Get the full attribute chain (e.g., 'contextlib.contextmanager')."""
+        if isinstance(node.value, ast.Name):
+            return f"{node.value.id}.{node.attr}"
+        if isinstance(node.value, ast.Attribute):
+            return f"{self._get_attr_chain(node.value)}.{node.attr}"
+        return node.attr
+
+    def _is_problematic_decorator(self, decorator_name: str) -> bool:
+        """Check if decorator is known to break pytest class collection."""
+        problematic_decorators = {
+            "conf_vars",
+            "env_vars",
+            "contextlib.contextmanager",
+            "contextmanager",
+        }
+        return decorator_name in problematic_decorators
+
+
+def check_file(filepath: Path) -> list[str]:
+    """Check a single file for problematic decorators."""
+    try:
+        with open(filepath, encoding="utf-8") as f:
+            content = f.read()
+
+        tree = ast.parse(content, filename=str(filepath))
+        checker = ContextManagerClassDecoratorChecker(str(filepath))
+        checker.visit(tree)
+        return checker.errors
+    except Exception as e:
+        return [f"{filepath}: Error parsing file: {e}"]
+
+
+def main() -> int:
+    """Main entry point."""
+    if len(sys.argv) < 2:
+        print("Usage: check_contextmanager_class_decorators.py 
<file_or_directory>...")
+        return 1
+
+    all_errors = []
+
+    for arg in sys.argv[1:]:
+        path = Path(arg)
+        if path.is_file() and path.suffix == ".py":
+            if "test" in str(path):  # Only check test files
+                all_errors.extend(check_file(path))
+            else:
+                print(f"Skipping non-test file: {path}")
+        elif path.is_dir():
+            for py_file in path.rglob("*.py"):
+                if "test" in str(py_file):  # Only check test files
+                    all_errors.extend(check_file(py_file))
+
+    if all_errors:
+        print("Found problematic context manager class decorators:")
+        for error in all_errors:
+            print(f"  {error}")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())

Reply via email to