This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new bd1e5c0a38a Remove global from task instance session (#58601)
bd1e5c0a38a is described below

commit bd1e5c0a38a199bbb45ccc4913721dff592335df
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Nov 28 02:50:20 2025 +0100

    Remove global from task instance session (#58601)
    
    * Remove global from task instance session
    
    * Fix pytest in standard provider for back-compat
---
 .../src/airflow/cli/commands/task_command.py       |  3 +-
 airflow-core/src/airflow/utils/db.py               |  8 +-
 .../src/airflow/utils/task_instance_session.py     | 60 ---------------
 .../tests/unit/models/test_renderedtifields.py     | 89 ++++++++++------------
 .../tests/unit/standard/decorators/test_python.py  | 40 +++++++++-
 .../tests/unit/weaviate/operators/test_weaviate.py | 16 ++--
 6 files changed, 90 insertions(+), 126 deletions(-)

diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index 7b6f3595a54..da086e88d75 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -54,7 +54,6 @@ from airflow.utils.platform import getuser
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import DagRunState, State
-from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
@@ -441,7 +440,7 @@ def task_render(args, dag: DAG | None = None) -> None:
         create_if_necessary="memory",
     )
 
-    with create_session() as session, set_current_task_instance_session(session=session):
+    with create_session() as session:
         context = ti.get_template_context(session=session)
         task = dag.get_task(args.task_id)
         # TODO (GH-52141): After sdk separation, ti.get_template_context() would
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 748b44acb86..231a74966c8 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -60,7 +60,6 @@ from airflow.utils import helpers
 from airflow.utils.db_manager import RunDBManager
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import get_dialect_name
-from airflow.utils.task_instance_session import get_current_task_instance_session
 
 USE_PSYCOPG3: bool
 try:
@@ -1546,7 +1545,7 @@ class LazySelectSequence(Sequence[T]):
 
     _select_asc: Select
     _select_desc: Select
-    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
+    _session: Session
     _len: int | None = attrs.field(init=False, default=None)
 
     @classmethod
@@ -1555,7 +1554,7 @@ class LazySelectSequence(Sequence[T]):
         select: Select,
         *,
         order_by: Sequence[ColumnElement],
-        session: Session | None = None,
+        session: Session,
     ) -> Self:
         s1 = select
         for col in order_by:
@@ -1563,7 +1562,7 @@ class LazySelectSequence(Sequence[T]):
         s2 = select
         for col in order_by:
             s2 = s2.order_by(col.desc())
-        return cls(s1, s2, session=session or get_current_task_instance_session())
+        return cls(s1, s2, session=session)
 
     @staticmethod
     def _rebuild_select(stmt: TextClause) -> Select:
@@ -1603,7 +1602,6 @@ class LazySelectSequence(Sequence[T]):
         s1, s2, self._len = state
         self._select_asc = self._rebuild_select(text(s1))
         self._select_desc = self._rebuild_select(text(s2))
-        self._session = get_current_task_instance_session()
 
     def __bool__(self) -> bool:
         return check_query_exists(self._select_asc, session=self._session)
diff --git a/airflow-core/src/airflow/utils/task_instance_session.py b/airflow-core/src/airflow/utils/task_instance_session.py
deleted file mode 100644
index 019a752c773..00000000000
--- a/airflow-core/src/airflow/utils/task_instance_session.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import contextlib
-import logging
-import traceback
-from typing import TYPE_CHECKING
-
-from airflow import settings
-
-if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
-__current_task_instance_session: Session | None = None
-
-log = logging.getLogger(__name__)
-
-
-def get_current_task_instance_session() -> Session:
-    global __current_task_instance_session
-    if not __current_task_instance_session:
-        log.warning("No task session set for this task. Continuing but this likely causes a resource leak.")
-        log.warning("Please report this and stacktrace below to https://github.com/apache/airflow/issues")
-        for filename, line_number, name, line in traceback.extract_stack():
-            log.warning('File: "%s", %s , in %s', filename, line_number, name)
-            if line:
-                log.warning("  %s", line.strip())
-        __current_task_instance_session = settings.get_session()()
-    return __current_task_instance_session
-
-
-@contextlib.contextmanager
-def set_current_task_instance_session(session: Session):
-    global __current_task_instance_session
-    if __current_task_instance_session:
-        raise RuntimeError(
-            "Session already set for this task. "
-            "You can only have one 'set_current_task_session' context manager active at a time."
-        )
-    __current_task_instance_session = session
-    try:
-        yield
-    finally:
-        __current_task_instance_session = None
diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py
index 8083459b6be..f31ed7722c0 100644
--- a/airflow-core/tests/unit/models/test_renderedtifields.py
+++ b/airflow-core/tests/unit/models/test_renderedtifields.py
@@ -41,7 +41,6 @@ from airflow.providers.standard.operators.bash import BashOperator
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import task as task_decorator
 from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_instance_session import set_current_task_instance_session
 
 from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
@@ -250,32 +249,29 @@ class TestRenderedTaskInstanceFields:
         Test that old records are deleted from rendered_task_instance_fields table
         for a given task_id and dag_id.
         """
-        with set_current_task_instance_session(session=session):
-            with dag_maker("test_delete_old_records") as dag:
-                task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
-            rtif_list = []
-            for num in range(rtif_num):
-                dr = dag_maker.create_dagrun(
-                    run_id=str(num), logical_date=dag.start_date + timedelta(days=num)
-                )
-                ti = dr.task_instances[0]
-                ti.task = task
-                rtif_list.append(RTIF(ti))
+        with dag_maker("test_delete_old_records") as dag:
+            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
+        rtif_list = []
+        for num in range(rtif_num):
+            dr = dag_maker.create_dagrun(run_id=str(num), logical_date=dag.start_date + timedelta(days=num))
+            ti = dr.task_instances[0]
+            ti.task = task
+            rtif_list.append(RTIF(ti))
 
-            session.add_all(rtif_list)
-            session.flush()
+        session.add_all(rtif_list)
+        session.flush()
 
-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
 
-            for rtif in rtif_list:
-                assert rtif in result
+        for rtif in rtif_list:
+            assert rtif in result
 
-            assert rtif_num == len(result)
+        assert rtif_num == len(result)
 
-            with assert_queries_count(expected_query_count):
-                RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
-            assert remaining_rtifs == len(result)
+        with assert_queries_count(expected_query_count):
+            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        assert remaining_rtifs == len(result)
 
     @pytest.mark.parametrize(
         ("num_runs", "num_to_keep", "remaining_rtifs", "expected_query_count"),
@@ -292,35 +288,32 @@ class TestRenderedTaskInstanceFields:
         Test that old records are deleted from rendered_task_instance_fields table
         for a given task_id and dag_id with mapped tasks.
         """
-        with set_current_task_instance_session(session=session):
-            with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
-                mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
-            for num in range(num_runs):
-                dr = dag_maker.create_dagrun(
-                    run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
-                )
+        with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
+            mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
+        for num in range(num_runs):
+            dr = dag_maker.create_dagrun(
+                run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
+            )
 
-                TaskMap.expand_mapped_task(
-                    dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session
-                )
-                session.refresh(dr)
-                for ti in dr.task_instances:
-                    ti.task = mapped
-                    session.add(RTIF(ti))
-            session.flush()
+            TaskMap.expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session)
+            session.refresh(dr)
+            for ti in dr.task_instances:
+                ti.task = mapped
+                session.add(RTIF(ti))
+        session.flush()
 
-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
-            assert len(result) == num_runs * 2
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
+        assert len(result) == num_runs * 2
 
-            with assert_queries_count(expected_query_count):
-                RTIF.delete_old_records(
-                    task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
-                )
-            result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
-            rtif_num_runs = Counter(rtif.run_id for rtif in result)
-            assert len(rtif_num_runs) == remaining_rtifs
-            # Check that we have _all_ the data for each row
-            assert len(result) == remaining_rtifs * 2
+        with assert_queries_count(expected_query_count):
+            RTIF.delete_old_records(
+                task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
+            )
+        result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
+        rtif_num_runs = Counter(rtif.run_id for rtif in result)
+        assert len(rtif_num_runs) == remaining_rtifs
+        # Check that we have _all_ the data for each row
+        assert len(result) == remaining_rtifs * 2
 
     def test_write(self, dag_maker):
         """
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index 391d4ae7f4d..cfc5df5ab2a 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -25,7 +25,6 @@ import pytest
 from airflow.exceptions import AirflowException, XComNotFound
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
-from airflow.utils.task_instance_session import set_current_task_instance_session
 
 from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
@@ -818,10 +817,49 @@ def test_mapped_decorator_unmap_merge_op_kwargs(dag_maker, session):
     assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
 
 
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
 def test_mapped_render_template_fields(dag_maker, session):
     @task_decorator
     def fn(arg1, arg2): ...
 
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
+    mapped_ti.map_index = 0
+    assert isinstance(mapped_ti.task, MappedOperator)
+    mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(mapped_ti.task, BaseOperator)
+
+    assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
+    assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+
+
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
+def test_mapped_render_template_fields_af2(dag_maker, session):
+    from airflow.utils.task_instance_session import set_current_task_instance_session
+
+    @task_decorator
+    def fn(arg1, arg2): ...
+
     with set_current_task_instance_session(session=session):
         with dag_maker(session=session):
             task1 = BaseOperator(task_id="op1")
diff --git a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
index aa3be7a3cd8..28b5164c626 100644
--- a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -20,8 +20,6 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from airflow.utils.task_instance_session import set_current_task_instance_session
-
 pytest.importorskip("weaviate")
 
 from airflow.providers.weaviate.operators.weaviate import (
@@ -87,10 +85,9 @@ class TestWeaviateIngestOperator:
 
         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}
 
 
 class TestWeaviateDocumentIngestOperator:
@@ -147,7 +144,6 @@ class TestWeaviateDocumentIngestOperator:
 
         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}

Reply via email to