This is an automated email from the ASF dual-hosted git repository. utkarsharma pushed a commit to branch v2-9-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 47b5bb67fb6195097950701f0570f6fa428a06bb Author: phi-friday <[email protected]> AuthorDate: Fri Jun 14 14:04:37 2024 +0900 Fix import future annotations in venv jinja template (#40208) (cherry picked from commit d5a75446a62ba8804879177ce394c7456adaa4d3) --- airflow/utils/python_virtualenv_script.jinja2 | 1 + tests/decorators/test_python_virtualenv.py | 30 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2 index 4199a47130..2ff417985e 100644 --- a/airflow/utils/python_virtualenv_script.jinja2 +++ b/airflow/utils/python_virtualenv_script.jinja2 @@ -16,6 +16,7 @@ specific language governing permissions and limitations under the License. -#} +from __future__ import annotations import {{ pickling_library }} import sys diff --git a/tests/decorators/test_python_virtualenv.py b/tests/decorators/test_python_virtualenv.py index 09631aafe8..62e8b7405f 100644 --- a/tests/decorators/test_python_virtualenv.py +++ b/tests/decorators/test_python_virtualenv.py @@ -20,17 +20,21 @@ from __future__ import annotations import datetime import sys from subprocess import CalledProcessError +from typing import Any import pytest from airflow.decorators import setup, task, teardown from airflow.utils import timezone +from airflow.utils.state import TaskInstanceState pytestmark = pytest.mark.db_test DEFAULT_DATE = timezone.datetime(2016, 1, 1) PYTHON_VERSION = sys.version_info[0] +_Invalid = Any + class TestPythonVirtualenvDecorator: def test_add_dill(self, dag_maker): @@ -250,3 +254,29 @@ class TestPythonVirtualenvDecorator: assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun ret.operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + def test_invalid_annotation(self, dag_maker): + import uuid + + unique_id = uuid.uuid4().hex + value = {"unique_id": unique_id} + + # Functions that throw an error + # if `from __future__ import annotations` is missing + @task.virtualenv(multiple_outputs=False, do_xcom_push=True) + def in_venv(value: dict[str, _Invalid]) -> _Invalid: + assert isinstance(value, dict) + return value["unique_id"] + + with dag_maker(): + ret = in_venv(value) + + dr = dag_maker.create_dagrun() + ret.operator.run(start_date=dr.execution_date, end_date=dr.execution_date) + ti = dr.get_task_instances()[0] + + assert ti.state == TaskInstanceState.SUCCESS + + xcom = ti.xcom_pull(task_ids=ti.task_id, key="return_value") + assert isinstance(xcom, str) + assert xcom == unique_id
