This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e95f24f866f Fix pytest collection failure for classes decorated with
context managers (#55915)
e95f24f866f is described below
commit e95f24f866fb5d9c81b37fd8f2ce7247889656c1
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Sep 23 04:42:46 2025 +0100
Fix pytest collection failure for classes decorated with context managers
(#55915)
Classes decorated with `@conf_vars` and other context managers were disappearing
during pytest collection, causing tests to be silently skipped. This affected
several test classes, including `TestWorkerStart` in the Celery provider tests.
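Root cause: when a `ContextDecorator` (for example the object returned by calling a
`@contextlib.contextmanager` function) is applied to a class, it replaces the class
with a plain wrapper function. pytest only collects actual classes as test classes,
so these wrapped classes are silently ignored during collection.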
Simple reproduction (no Airflow needed):
```py
import contextlib
import inspect


@contextlib.contextmanager
def simple_cm():
    yield


# Applying the context manager to a class swaps the class for a plain
# wrapper function (ContextDecorator.__call__), so it is no longer a type.
@simple_cm()
class TestExample:
    def test_method(self):
        pass


# Prints False - pytest won't collect TestExample because it is not a class anymore.
print(f'Is class? {inspect.isclass(TestExample)}')
```
and then run
```shell
pytest test_example.py --collect-only
```
Airflow reproduction:
```shell
breeze run pytest providers/celery/tests/unit/celery/cli/test_celery_command.py --collect-only -v
```
Solution:
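1. Replaced the class-level `@conf_vars` decorators in the affected test files
with pytest fixtures
2. Kafka integration tests: autouse fixtures now create the test connections
with `create_connection_without_db` instead of injecting them through `conf_vars`
3. Celery CLI tests: a `conf_stale_bundle_cleanup_disabled` fixture applies the
config override and is attached to each test class with `@pytest.mark.usefixtures`
4. Added a custom linter to prevent future occurrences and integrated it
into the pre-commit hooks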
Files changed:
- Fixed 3 test files with problematic class decorators
- Added custom linter with pre-commit integration
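This ensures pytest collects the affected test classes again. The new check flags
`Test*` classes decorated with known offenders (`conf_vars`, `env_vars`,
`contextmanager`, `contextlib.contextmanager`), so such regressions are caught automatically.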
---
.pre-commit-config.yaml | 6 +
.../apache/kafka/operators/test_consume.py | 41 ++++---
.../apache/kafka/operators/test_produce.py | 33 ++---
.../tests/unit/celery/cli/test_celery_command.py | 14 ++-
.../prek/check_contextmanager_class_decorators.py | 133 +++++++++++++++++++++
5 files changed, 191 insertions(+), 36 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5c74a9cad8e..91bfd6c3203 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1649,3 +1649,9 @@ repos:
         files: ^airflow-core/src/airflow/serialization/schema\.json$|^airflow-core/src/airflow/serialization/serialized_objects\.py$
         pass_filenames: false
         require_serial: true
+      - id: check-contextmanager-class-decorators
+        name: Check for problematic context manager class decorators
+        entry: ./scripts/ci/prek/check_contextmanager_class_decorators.py
+        language: python
+        files: .*test.*\.py$
+        pass_filenames: true
diff --git a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
index ddca5fc502e..aefcee1d753 100644
--- a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
+++ b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_consume.py
@@ -25,10 +25,9 @@ import pytest
 from confluent_kafka import Producer
 
 # Import Operator
+from airflow.models.connection import Connection
 from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator
 
-from tests_common.test_utils.config import conf_vars
-
 log = logging.getLogger(__name__)
 
 
@@ -49,23 +48,29 @@ def _basic_message_tester(message, test=None) -> Any:
     assert message.value().decode(encoding="utf-8") == test
 
 
+@pytest.fixture(autouse=True)
+def kafka_consumer_connections(create_connection_without_db):
+    """Create Kafka consumer connections for testing purpose."""
+    connections = [
+        Connection(
+            conn_id="operator.consumer.test.integration.test_1",
+            uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+        Connection(
+            conn_id="operator.consumer.test.integration.test_2",
+            uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+        Connection(
+            conn_id="operator.consumer.test.integration.test_3",
+            uri="kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
+        ),
+    ]
+
+    for conn in connections:
+        create_connection_without_db(conn)
+
+
 @pytest.mark.integration("kafka")
-@conf_vars(
-    {
-        (
-            "connections",
-            "operator.consumer.test.integration.test_1",
-        ): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_1&enable.auto.commit=False&auto.offset.reset=beginning",
-        (
-            "connections",
-            "operator.consumer.test.integration.test_2",
-        ): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_2&enable.auto.commit=False&auto.offset.reset=beginning",
-        (
-            "connections",
-            "operator.consumer.test.integration.test_3",
-        ): "kafka://broker:29092?socket.timeout.ms=10&bootstrap.servers=broker:29092&group.id=operator.consumer.test.integration.test_3&enable.auto.commit=False&auto.offset.reset=beginning",
-    }
-)
 class TestConsumeFromTopic:
     """
     test ConsumeFromTopicOperator
diff --git a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
index ada93900ce6..d7a5167916d 100644
--- a/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
+++ b/providers/apache/kafka/tests/integration/apache/kafka/operators/test_produce.py
@@ -22,10 +22,9 @@ import logging
 import pytest
 from confluent_kafka import Consumer
 
+from airflow.models.connection import Connection
 from airflow.providers.apache.kafka.operators.produce import ProduceToTopicOperator
 
-from tests_common.test_utils.config import conf_vars
-
 log = logging.getLogger(__name__)
 
 
@@ -34,19 +33,25 @@ def _producer_function():
         yield (json.dumps(i), json.dumps(i + 1))
 
 
+@pytest.fixture(autouse=True)
+def kafka_connections(create_connection_without_db):
+    """Create Kafka producer connections for testing purpose."""
+    connections = [
+        Connection(
+            conn_id="kafka_default_test_1",
+            uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
+        ),
+        Connection(
+            conn_id="kafka_default_test_2",
+            uri="kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
+        ),
+    ]
+
+    for conn in connections:
+        create_connection_without_db(conn)
+
+
 @pytest.mark.integration("kafka")
-@conf_vars(
-    {
-        (
-            "connections",
-            "kafka_default_test_1",
-        ): "kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_1",
-        (
-            "connections",
-            "kafka_default_test_2",
-        ): "kafka://broker:29092?socket.timeout.ms=10&message.timeout.ms=10&group.id=operator.producer.test.integration.test_2",
-    }
-)
 class TestProduceToTopic:
     """
     test ProduceToTopicOperator
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index 41f74c9c4a6..94e339804e3 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -37,8 +37,14 @@ from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
 
 
+@pytest.fixture(autouse=False)
+def conf_stale_bundle_cleanup_disabled():
+    with conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): "0"}):
+        yield
+
+
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
+@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
 class TestCeleryStopCommand:
     @classmethod
     def setup_class(cls):
@@ -120,7 +126,7 @@ class TestCeleryStopCommand:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
+@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
 class TestWorkerStart:
     @classmethod
     def setup_class(cls):
@@ -181,7 +187,7 @@ class TestWorkerStart:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
+@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
 class TestWorkerFailure:
     @classmethod
     def setup_class(cls):
@@ -201,7 +207,7 @@ class TestWorkerFailure:
 
 
 @pytest.mark.backend("mysql", "postgres")
-@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
+@pytest.mark.usefixtures("conf_stale_bundle_cleanup_disabled")
 class TestFlowerCommand:
     @classmethod
     def setup_class(cls):
diff --git a/scripts/ci/prek/check_contextmanager_class_decorators.py b/scripts/ci/prek/check_contextmanager_class_decorators.py
new file mode 100644
index 00000000000..149d00bc6c2
--- /dev/null
+++ b/scripts/ci/prek/check_contextmanager_class_decorators.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Check for problematic context manager decorators on test classes.
+
+Context managers (ContextDecorator, @contextlib.contextmanager) when used as class decorators
+transform the class into a callable wrapper, which prevents pytest from collecting the class.
+"""
+
+from __future__ import annotations
+
+import ast
+import sys
+from pathlib import Path
+
+
+class ContextManagerClassDecoratorChecker(ast.NodeVisitor):
+    """AST visitor to check for context manager decorators on test classes."""
+
+    def __init__(self, filename: str):
+        self.filename = filename
+        self.errors: list[str] = []
+
+    def visit_ClassDef(self, node: ast.ClassDef) -> None:
+        """Check class definitions for problematic decorators."""
+        if not node.name.startswith("Test"):
+            self.generic_visit(node)
+            return
+
+        for decorator in node.decorator_list:
+            decorator_name = self._get_decorator_name(decorator)
+            if self._is_problematic_decorator(decorator_name):
+                self.errors.append(
+                    f"{self.filename}:{node.lineno}: Class '{node.name}' uses @{decorator_name} "
+                    f"decorator which prevents pytest collection. Use @pytest.mark.usefixtures instead."
+                )
+
+        self.generic_visit(node)
+
+    def _get_decorator_name(self, decorator: ast.expr) -> str:
+        """Extract decorator name from AST node."""
+        if isinstance(decorator, ast.Name):
+            return decorator.id
+        if isinstance(decorator, ast.Call):
+            if isinstance(decorator.func, ast.Name):
+                return decorator.func.id
+            if isinstance(decorator.func, ast.Attribute):
+                return f"{self._get_attr_chain(decorator.func)}"
+        elif isinstance(decorator, ast.Attribute):
+            return f"{self._get_attr_chain(decorator)}"
+        return "unknown"
+
+    def _get_attr_chain(self, node: ast.Attribute) -> str:
+        """Get the full attribute chain (e.g., 'contextlib.contextmanager')."""
+        if isinstance(node.value, ast.Name):
+            return f"{node.value.id}.{node.attr}"
+        if isinstance(node.value, ast.Attribute):
+            return f"{self._get_attr_chain(node.value)}.{node.attr}"
+        return node.attr
+
+    def _is_problematic_decorator(self, decorator_name: str) -> bool:
+        """Check if decorator is known to break pytest class collection."""
+        problematic_decorators = {
+            "conf_vars",
+            "env_vars",
+            "contextlib.contextmanager",
+            "contextmanager",
+        }
+        return decorator_name in problematic_decorators
+
+
+def check_file(filepath: Path) -> list[str]:
+    """Check a single file for problematic decorators."""
+    try:
+        with open(filepath, encoding="utf-8") as f:
+            content = f.read()
+
+        tree = ast.parse(content, filename=str(filepath))
+        checker = ContextManagerClassDecoratorChecker(str(filepath))
+        checker.visit(tree)
+        return checker.errors
+    except Exception as e:
+        return [f"{filepath}: Error parsing file: {e}"]
+
+
+def main() -> int:
+    """Main entry point."""
+    if len(sys.argv) < 2:
+        print("Usage: check_contextmanager_class_decorators.py <file_or_directory>...")
+        return 1
+
+    all_errors = []
+
+    for arg in sys.argv[1:]:
+        path = Path(arg)
+        if path.is_file() and path.suffix == ".py":
+            if "test" in str(path):  # Only check test files
+                all_errors.extend(check_file(path))
+            else:
+                print(f"Skipping non-test file: {path}")
+        elif path.is_dir():
+            for py_file in path.rglob("*.py"):
+                if "test" in str(py_file):  # Only check test files
+                    all_errors.extend(check_file(py_file))
+
+    if all_errors:
+        print("Found problematic context manager class decorators:")
+        for error in all_errors:
+            print(f" {error}")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())