This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da553935d2 Send context using in venv operator (#41039)
da553935d2 is described below

commit da553935d248f22695124c40777d3ea29e04d57f
Author: phi-friday <[email protected]>
AuthorDate: Fri Aug 9 02:25:09 2024 +0900

    Send context using in venv operator (#41039)
---
 airflow/decorators/__init__.pyi                    |   6 +
 .../example_python_context_decorator.py            |  92 ++++++++++++++++
 .../example_python_context_operator.py             |  91 +++++++++++++++
 airflow/operators/python.py                        |  36 ++++++
 airflow/utils/python_virtualenv_script.jinja2      |  23 ++++
 docs/apache-airflow/howto/operator/python.rst      |  92 ++++++++++++++++
 newsfragments/41039.feature.rst                    |   1 +
 tests/operators/test_python.py                     | 122 ++++++++++++++++++++-
 8 files changed, 462 insertions(+), 1 deletion(-)

diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index faf77e8240..089e453d02 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -125,6 +125,7 @@ class TaskDecoratorCollection:
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ) -> TaskDecorator:
         """Create a decorator to convert the decorated callable to a virtual 
environment task.
@@ -176,6 +177,7 @@ class TaskDecoratorCollection:
         :param use_dill: Deprecated, use ``serializer`` instead. Whether to 
use dill to serialize
             the args and result (pickle is default). This allows more complex 
types
             but requires you to include dill in your requirements.
+        :param use_airflow_context: Whether to provide 
``get_current_context()`` to the python_callable.
         """
     @overload
     def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> 
Task[FParams, FReturn]: ...
@@ -192,6 +194,7 @@ class TaskDecoratorCollection:
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ) -> TaskDecorator:
         """Create a decorator to convert the decorated callable to a virtual 
environment task.
@@ -225,6 +228,7 @@ class TaskDecoratorCollection:
         :param use_dill: Deprecated, use ``serializer`` instead. Whether to 
use dill to serialize
             the args and result (pickle is default). This allows more complex 
types
             but requires you to include dill in your requirements.
+        :param use_airflow_context: Whether to provide 
``get_current_context()`` to the python_callable.
         """
     @overload
     def branch(  # type: ignore[misc]
@@ -258,6 +262,7 @@ class TaskDecoratorCollection:
         venv_cache_path: None | str = None,
         show_return_value_in_logs: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ) -> TaskDecorator:
         """Create a decorator to wrap the decorated callable into a 
BranchPythonVirtualenvOperator.
@@ -299,6 +304,7 @@ class TaskDecoratorCollection:
         :param use_dill: Deprecated, use ``serializer`` instead. Whether to 
use dill to serialize
             the args and result (pickle is default). This allows more complex 
types
             but requires you to include dill in your requirements.
+        :param use_airflow_context: Whether to provide 
``get_current_context()`` to the python_callable.
         """
     @overload
     def branch_virtualenv(self, python_callable: Callable[FParams, FReturn]) 
-> Task[FParams, FReturn]: ...
diff --git a/airflow/example_dags/example_python_context_decorator.py 
b/airflow/example_dags/example_python_context_decorator.py
new file mode 100644
index 0000000000..497ee08e17
--- /dev/null
+++ b/airflow/example_dags/example_python_context_decorator.py
@@ -0,0 +1,92 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG demonstrating `get_current_context()` with the TaskFlow API.
+
+Shows how a `@task`, a `@task.virtualenv` and a `@task.external_python` task can
+each fetch the current Airflow context; the virtualenv and external-python tasks
+opt in with `use_airflow_context=True`.
+"""
+
+from __future__ import annotations
+
+import sys
+
+import pendulum
+
+from airflow.decorators import dag, task
+
+# Interpreter used by the external-python task; reusing the running interpreter
+# keeps the example self-contained.
+SOME_EXTERNAL_PYTHON = sys.executable
+
+
+@dag(
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=["example"],
+)
+def example_python_context_decorator():
+    # [START get_current_context]
+    @task(task_id="print_the_context")
+    def print_context() -> str:
+        """Print the Airflow context."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context = print_context()
+    # [END get_current_context]
+
+    # [START get_current_context_venv]
+    @task.virtualenv(task_id="print_the_context_venv", use_airflow_context=True)
+    def print_context_venv() -> str:
+        """Print the Airflow context in venv."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context_venv = print_context_venv()
+    # [END get_current_context_venv]
+
+    # [START get_current_context_external]
+    @task.external_python(
+        task_id="print_the_context_external", python=SOME_EXTERNAL_PYTHON, use_airflow_context=True
+    )
+    def print_context_external() -> str:
+        """Print the Airflow context in external python."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context_external = print_context_external()
+    # [END get_current_context_external]
+
+    # The virtualenv and external-python tasks both run after the in-process one.
+    _ = print_the_context >> [print_the_context_venv, print_the_context_external]
+
+
+example_python_context_decorator()
diff --git a/airflow/example_dags/example_python_context_operator.py 
b/airflow/example_dags/example_python_context_operator.py
new file mode 100644
index 0000000000..f1b76c527c
--- /dev/null
+++ b/airflow/example_dags/example_python_context_operator.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG demonstrating `get_current_context()` with the classic Python operators.
+
+Shows how PythonOperator, PythonVirtualenvOperator and ExternalPythonOperator tasks
+can each fetch the current Airflow context; the virtualenv and external-python
+operators opt in with `use_airflow_context=True`.
+"""
+
+from __future__ import annotations
+
+import sys
+
+import pendulum
+
+from airflow import DAG
+from airflow.operators.python import ExternalPythonOperator, PythonOperator, PythonVirtualenvOperator
+
+# Interpreter used by the external-python task; reusing the running interpreter
+# keeps the example self-contained.
+SOME_EXTERNAL_PYTHON = sys.executable
+
+with DAG(
+    dag_id="example_python_context_operator",
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=["example"],
+) as dag:
+    # [START get_current_context]
+    def print_context() -> str:
+        """Print the Airflow context."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context = PythonOperator(task_id="print_the_context", python_callable=print_context)
+    # [END get_current_context]
+
+    # [START get_current_context_venv]
+    def print_context_venv() -> str:
+        """Print the Airflow context in venv."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context_venv = PythonVirtualenvOperator(
+        task_id="print_the_context_venv", python_callable=print_context_venv, use_airflow_context=True
+    )
+    # [END get_current_context_venv]
+
+    # [START get_current_context_external]
+    def print_context_external() -> str:
+        """Print the Airflow context in external python."""
+        from pprint import pprint
+
+        from airflow.operators.python import get_current_context
+
+        context = get_current_context()
+        pprint(context)
+        return "Whatever you return gets printed in the logs"
+
+    print_the_context_external = ExternalPythonOperator(
+        task_id="print_the_context_external",
+        python_callable=print_context_external,
+        python=SOME_EXTERNAL_PYTHON,
+        use_airflow_context=True,
+    )
+    # [END get_current_context_external]
+
+    # The virtualenv and external-python tasks both run after the in-process one.
+    _ = print_the_context >> [print_the_context_venv, print_the_context_external]
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index fdfe575fb9..ce6ccd3a40 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -56,12 +56,14 @@ from airflow.utils.file import get_unique_dag_module_name
 from airflow.utils.operator_helpers import ExecutionCallableRunner, 
KeywordParameters
 from airflow.utils.process_utils import execute_in_subprocess
 from airflow.utils.python_virtualenv import prepare_virtualenv, 
write_python_script
+from airflow.utils.session import create_session
 
 log = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from pendulum.datetime import DateTime
 
+    from airflow.serialization.enums import Encoding
     from airflow.utils.context import Context
 
 
@@ -442,6 +444,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ):
         if (
@@ -494,6 +497,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
         )
         self.env_vars = env_vars
         self.inherit_env = inherit_env
+        self.use_airflow_context = use_airflow_context
 
     @abstractmethod
     def _iter_serializable_context_keys(self):
@@ -540,6 +544,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
             string_args_path = tmp_dir / "string_args.txt"
             script_path = tmp_dir / "script.py"
             termination_log_path = tmp_dir / "termination.log"
+            airflow_context_path = tmp_dir / "airflow_context.json"
 
             self._write_args(input_path)
             self._write_string_args(string_args_path)
@@ -551,6 +556,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
                 "pickling_library": self.serializer,
                 "python_callable": self.python_callable.__name__,
                 "python_callable_source": self.get_python_source(),
+                "use_airflow_context": self.use_airflow_context,
             }
 
             if inspect.getfile(self.python_callable) == self.dag.fileloc:
@@ -561,6 +567,23 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
                 filename=os.fspath(script_path),
                 
render_template_as_native_obj=self.dag.render_template_as_native_obj,
             )
+            if self.use_airflow_context:
+                from airflow.serialization.serialized_objects import 
BaseSerialization
+
+                context = get_current_context()
+                # TODO: `TaskInstance`` will also soon be serialized as 
expected.
+                # see more:
+                #   https://github.com/apache/airflow/issues/40974
+                #   https://github.com/apache/airflow/pull/41067
+                with create_session() as session:
+                    # FIXME: DetachedInstanceError
+                    dag_run, task_instance = context["dag_run"], 
context["task_instance"]
+                    session.add_all([dag_run, task_instance])
+                    serializable_context: dict[Encoding, Any] = 
BaseSerialization.serialize(
+                        context, use_pydantic_models=True
+                    )
+                with airflow_context_path.open("w+") as file:
+                    json.dump(serializable_context, file)
 
             env_vars = dict(os.environ) if self.inherit_env else {}
             if self.env_vars:
@@ -575,6 +598,7 @@ class _BasePythonVirtualenvOperator(PythonOperator, 
metaclass=ABCMeta):
                         os.fspath(output_path),
                         os.fspath(string_args_path),
                         os.fspath(termination_log_path),
+                        os.fspath(airflow_context_path),
                     ],
                     env=env_vars,
                 )
@@ -666,6 +690,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
     :param use_dill: Deprecated, use ``serializer`` instead. Whether to use 
dill to serialize
         the args and result (pickle is default). This allows more complex types
         but requires you to include dill in your requirements.
+    :param use_airflow_context: Whether to provide ``get_current_context()`` 
to the python_callable.
     """
 
     template_fields: Sequence[str] = tuple(
@@ -694,6 +719,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ):
         if (
@@ -715,6 +741,9 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
             )
         if not is_venv_installed():
             raise AirflowException("PythonVirtualenvOperator requires 
virtualenv, please install it.")
+        if use_airflow_context and (not expect_airflow and not 
system_site_packages):
+            error_msg = "use_airflow_context is set to True, but 
expect_airflow and system_site_packages are set to False."
+            raise AirflowException(error_msg)
         if not requirements:
             self.requirements: list[str] = []
         elif isinstance(requirements, str):
@@ -744,6 +773,7 @@ class 
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
             env_vars=env_vars,
             inherit_env=inherit_env,
             use_dill=use_dill,
+            use_airflow_context=use_airflow_context,
             **kwargs,
         )
 
@@ -962,6 +992,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
     :param use_dill: Deprecated, use ``serializer`` instead. Whether to use 
dill to serialize
         the args and result (pickle is default). This allows more complex types
         but requires you to include dill in your requirements.
+    :param use_airflow_context: Whether to provide ``get_current_context()`` 
to the python_callable.
     """
 
     template_fields: Sequence[str] = 
tuple({"python"}.union(PythonOperator.template_fields))
@@ -983,10 +1014,14 @@ class 
ExternalPythonOperator(_BasePythonVirtualenvOperator):
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ):
         if not python:
             raise ValueError("Python Path must be defined in 
ExternalPythonOperator")
+        if use_airflow_context and not expect_airflow:
+            error_msg = "use_airflow_context is set to True, but 
expect_airflow is set to False."
+            raise AirflowException(error_msg)
         self.python = python
         self.expect_pendulum = expect_pendulum
         super().__init__(
@@ -1002,6 +1037,7 @@ class 
ExternalPythonOperator(_BasePythonVirtualenvOperator):
             env_vars=env_vars,
             inherit_env=inherit_env,
             use_dill=use_dill,
+            use_airflow_context=use_airflow_context,
             **kwargs,
         )
 
diff --git a/airflow/utils/python_virtualenv_script.jinja2 
b/airflow/utils/python_virtualenv_script.jinja2
index 2ff417985e..22d68acd75 100644
--- a/airflow/utils/python_virtualenv_script.jinja2
+++ b/airflow/utils/python_virtualenv_script.jinja2
@@ -64,6 +64,29 @@ with open(sys.argv[3], "r") as file:
     virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
 {% endif %}
 
+{% if use_airflow_context | default(false) -%}
+if len(sys.argv) > 5:
+    # Swap ``airflow.operators.python`` in sys.modules for a proxy module whose
+    # ``get_current_context()`` returns the context the parent process serialized
+    # to sys.argv[5]; every other attribute is forwarded to the real module.
+    # NOTE(review): only imports resolved through sys.modules (e.g.
+    # ``from airflow.operators.python import get_current_context``) see the proxy;
+    # ``from airflow.operators import python`` returns the real submodule, which is
+    # already bound on the package by the import below.
+    import functools
+    import json
+    from types import ModuleType
+
+    from airflow.operators import python as airflow_python
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+
+    class _MockPython(ModuleType):
+        @staticmethod
+        @functools.lru_cache(maxsize=1)
+        def get_current_context():
+            # Read and deserialize the context file only once. Every later call
+            # returns the same object rather than re-parsing the file into a fresh
+            # copy (presumably matching the in-process get_current_context(); verify).
+            with open(sys.argv[5]) as file:
+                context = json.load(file)
+            return BaseSerialization.deserialize(context, use_pydantic_models=True)
+
+        def __getattr__(self, name: str):
+            return getattr(airflow_python, name)
+
+
+    MockPython = _MockPython("MockPython")
+    sys.modules["airflow.operators.python"] = MockPython
+{% endif %}
 
 try:
     res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
diff --git a/docs/apache-airflow/howto/operator/python.rst 
b/docs/apache-airflow/howto/operator/python.rst
index b8619cd38b..5b5a60b6bc 100644
--- a/docs/apache-airflow/howto/operator/python.rst
+++ b/docs/apache-airflow/howto/operator/python.rst
@@ -102,6 +102,37 @@ is evaluated as a :ref:`Jinja template 
<concepts:jinja-templating>`.
             :start-after: [START howto_operator_python_render_sql]
             :end-before: [END howto_operator_python_render_sql]
 
+Context
+^^^^^^^
+
+The ``Context`` is a dictionary object that contains information
+about the environment of the ``DagRun``.
+For example, its ``task_instance`` entry holds the currently running ``TaskInstance`` object.
+
+It can be received implicitly, for example through ``**kwargs``,
+or fetched explicitly with ``get_current_context()``.
+The explicit form returns a typed ``Context``, so static analysis tools can check how it is used.
+
+.. tab-set::
+
+    .. tab-item:: @task
+        :sync: taskflow
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context]
+            :end-before: [END get_current_context]
+
+    .. tab-item:: PythonOperator
+        :sync: operator
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context]
+            :end-before: [END get_current_context]
+
 .. _howto/operator:PythonVirtualenvOperator:
 
 PythonVirtualenvOperator
@@ -203,6 +234,42 @@ In case you have problems during runtime with broken 
cached virtual environments
 Note that any modification of a cached virtual environment (like temp files in 
binary path, post-installing further requirements) might pollute a cached 
virtual environment and the
 operator is not maintaining or cleaning the cache path.
 
+Context
+^^^^^^^
+
+With some limitations, you can also use ``Context`` in virtual environments.
+
+.. important::
+    Using ``Context`` in a virtual environment is tricky
+    because it involves library dependencies and serialization issues.
+
+    You can partly work around this by using :ref:`Jinja template variables <templates:variables>` and passing the values you need explicitly as parameters.
+
+    You can also call ``get_current_context()`` as usual, subject to these requirements:
+
+    * Set ``use_airflow_context`` to ``True`` so that ``get_current_context()`` works in the virtual environment.
+
+    * Set ``system_site_packages`` or ``expect_airflow`` to ``True``, because Airflow itself must be importable in the virtual environment.
+
+.. tab-set::
+
+    .. tab-item:: @task.virtualenv
+        :sync: taskflow
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_venv]
+            :end-before: [END get_current_context_venv]
+
+    .. tab-item:: PythonVirtualenvOperator
+        :sync: operator
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_venv]
+            :end-before: [END get_current_context_venv]
 
 .. _howto/operator:ExternalPythonOperator:
 
@@ -267,6 +334,31 @@ If you want the context related to datetime objects like 
``data_interval_start``
     If you want to pass variables into the classic 
:class:`~airflow.operators.python.ExternalPythonOperator` use
     ``op_args`` and ``op_kwargs``.
 
+Context
+^^^^^^^
+
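+You can use ``Context`` under the same conditions as ``PythonVirtualenvOperator``, except that only ``expect_airflow`` is taken into account: with ``use_airflow_context=True``, ``expect_airflow`` must also be ``True``.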
+
+.. tab-set::
+
+    .. tab-item:: @task.external_python
+        :sync: taskflow
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_external]
+            :end-before: [END get_current_context_external]
+
+    .. tab-item:: ExternalPythonOperator
+        :sync: operator
+
+        .. exampleinclude:: 
/../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_external]
+            :end-before: [END get_current_context_external]
+
 .. _howto/operator:PythonBranchOperator:
 
 PythonBranchOperator
diff --git a/newsfragments/41039.feature.rst b/newsfragments/41039.feature.rst
new file mode 100644
index 0000000000..c696d25f87
--- /dev/null
+++ b/newsfragments/41039.feature.rst
@@ -0,0 +1 @@
+Enable ``get_current_context()`` to work in virtual environments via the new ``use_airflow_context`` parameter. The following ``Operators`` are affected: ``PythonVirtualenvOperator``, ``BranchPythonVirtualenvOperator``, ``ExternalPythonOperator``, ``BranchExternalPythonOperator``.
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 993d70cad3..9148ae18b7 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -39,10 +39,15 @@ import pytest
 from slugify import slugify
 
 from airflow.decorators import task_group
-from airflow.exceptions import AirflowException, DeserializingResultError, 
RemovedInAirflow3Warning
+from airflow.exceptions import (
+    AirflowException,
+    DeserializingResultError,
+    RemovedInAirflow3Warning,
+)
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance, clear_task_instances, 
set_current_context
+from airflow.operators.branch import BranchMixIn
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import (
     BranchExternalPythonOperator,
@@ -1005,6 +1010,75 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
         task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, 
inherit_env=True)
         assert task.execute_callable() == "EFGHI"
 
+    def test_branch_current_context(self):
+        """Branch virtualenv operators must also serve get_current_context() to the callable."""
+        if not issubclass(self.opcls, BranchMixIn):
+            pytest.skip("This test is only applicable to BranchMixIn")
+
+        def f():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if not isinstance(context, Context):  # type: ignore[misc]
+                error_msg = f"Expected Context, got {type(context)}"
+                raise TypeError(error_msg)
+
+            # Return [] like the sibling context tests in this class.
+            return []
+
+        ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True)
+        assert ti.state == TaskInstanceState.SUCCESS
+
+    def test_current_context(self):
+        """With use_airflow_context=True, get_current_context() in the subprocess yields a Context."""
+
+        def check_context_type():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if isinstance(context, Context):  # type: ignore[misc]
+                return []
+            raise TypeError(f"Expected Context, got {type(context)}")
+
+        task_instance = self.run_as_task(
+            check_context_type,
+            return_ti=True,
+            multiple_outputs=False,
+            use_airflow_context=True,
+        )
+        assert task_instance.state == TaskInstanceState.SUCCESS
+
+    def test_current_context_not_found_error(self):
+        """Without use_airflow_context, get_current_context() in the subprocess must fail."""
+        import re
+
+        def f():
+            from airflow.operators.python import get_current_context
+
+            get_current_context()
+            return []
+
+        # pytest.raises(match=...) runs re.search, and the "?" in this message is a
+        # regex quantifier; escape the text so it is matched literally.
+        expected_msg = re.escape(
+            "Current context was requested but no context was found! "
+            "Are you running within an airflow task?"
+        )
+        with pytest.raises(AirflowException, match=expected_msg):
+            self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=False)
+
+    def test_current_context_airflow_not_found_error(self):
+        """use_airflow_context=True is rejected when the Airflow-availability flags are all False.
+
+        ExternalPythonOperator only checks expect_airflow; the virtualenv operators
+        also accept system_site_packages, so both are switched off for them.
+        """
+        import re
+
+        airflow_flag: dict[str, bool] = {"expect_airflow": False}
+        error_msg = "use_airflow_context is set to True, but expect_airflow is set to False."
+
+        if not issubclass(self.opcls, ExternalPythonOperator):
+            airflow_flag["system_site_packages"] = False
+            error_msg = (
+                "use_airflow_context is set to True, "
+                "but expect_airflow and system_site_packages are set to False."
+            )
+
+        def f():
+            from airflow.operators.python import get_current_context
+
+            get_current_context()
+            return []
+
+        # Escape the message: unescaped, its trailing "." would match any character.
+        with pytest.raises(AirflowException, match=re.escape(error_msg)):
+            self.run_as_task(
+                f, return_ti=True, multiple_outputs=False, use_airflow_context=True, **airflow_flag
+            )
+
+    def test_use_airflow_context_touch_other_variables(self):
+        """The patched airflow.operators.python module must still expose its other names."""
+
+        def check_context_and_module_attrs():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if not isinstance(context, Context):  # type: ignore[misc]
+                raise TypeError(f"Expected Context, got {type(context)}")
+
+            # Importing another name afterwards checks that attribute lookups are
+            # still forwarded to the real module.
+            from airflow.operators.python import PythonOperator  # noqa: F401
+
+            return []
+
+        task_instance = self.run_as_task(
+            check_context_and_module_attrs,
+            return_ti=True,
+            multiple_outputs=False,
+            use_airflow_context=True,
+        )
+        assert task_instance.state == TaskInstanceState.SUCCESS
+
 
 venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
 
@@ -1426,6 +1500,29 @@ class 
TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
 
         self.run_as_task(f, serializer=serializer, system_site_packages=False, 
requirements=None)
 
+    def test_current_context_system_site_packages(self, session):
+        """Airflow reachable via system site-packages suffices when expect_airflow is False."""
+
+        def check_context_type():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if isinstance(context, Context):  # type: ignore[misc]
+                return []
+            raise TypeError(f"Expected Context, got {type(context)}")
+
+        venv_kwargs = {
+            "use_airflow_context": True,
+            "expect_airflow": False,
+            "system_site_packages": True,
+        }
+        task_instance = self.run_as_task(
+            check_context_type, return_ti=True, multiple_outputs=False, session=session, **venv_kwargs
+        )
+        assert task_instance.state == TaskInstanceState.SUCCESS
+
 
 # when venv tests are run in parallel to other test they create new processes 
and this might take
 # quite some time in shared docker environment and get some contention even 
between different containers
@@ -1745,6 +1842,29 @@ class 
TestBranchPythonVirtualenvOperator(BaseTestBranchPythonVirtualenvOperator)
                 kwargs["venv_cache_path"] = venv_cache_path
         return kwargs
 
+    def test_current_context_system_site_packages(self, session):
+        """Branch variant: system site-packages alone makes get_current_context() usable."""
+
+        def check_context_type():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if isinstance(context, Context):  # type: ignore[misc]
+                return []
+            raise TypeError(f"Expected Context, got {type(context)}")
+
+        venv_kwargs = {
+            "use_airflow_context": True,
+            "expect_airflow": False,
+            "system_site_packages": True,
+        }
+        task_instance = self.run_as_task(
+            check_context_type, return_ti=True, multiple_outputs=False, session=session, **venv_kwargs
+        )
+        assert task_instance.state == TaskInstanceState.SUCCESS
+
 
 # when venv tests are run in parallel to other test they create new processes 
and this might take
 # quite some time in shared docker environment and get some contention even 
between different containers

Reply via email to