This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new bd1e5c0a38a Remove global from task instance session (#58601)
bd1e5c0a38a is described below
commit bd1e5c0a38a199bbb45ccc4913721dff592335df
Author: Jens Scheffler <[email protected]>
AuthorDate: Fri Nov 28 02:50:20 2025 +0100
Remove global from task instance session (#58601)
* Remove global from task instance session
* Fix pytest in standard provider for back-compat
---
.../src/airflow/cli/commands/task_command.py | 3 +-
airflow-core/src/airflow/utils/db.py | 8 +-
.../src/airflow/utils/task_instance_session.py | 60 ---------------
.../tests/unit/models/test_renderedtifields.py | 89 ++++++++++------------
.../tests/unit/standard/decorators/test_python.py | 40 +++++++++-
.../tests/unit/weaviate/operators/test_weaviate.py | 16 ++--
6 files changed, 90 insertions(+), 126 deletions(-)
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py
b/airflow-core/src/airflow/cli/commands/task_command.py
index 7b6f3595a54..da086e88d75 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -54,7 +54,6 @@ from airflow.utils.platform import getuser
from airflow.utils.providers_configuration_loader import
providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, State
-from airflow.utils.task_instance_session import
set_current_task_instance_session
from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
@@ -441,7 +440,7 @@ def task_render(args, dag: DAG | None = None) -> None:
create_if_necessary="memory",
)
- with create_session() as session,
set_current_task_instance_session(session=session):
+ with create_session() as session:
context = ti.get_template_context(session=session)
task = dag.get_task(args.task_id)
# TODO (GH-52141): After sdk separation, ti.get_template_context()
would
diff --git a/airflow-core/src/airflow/utils/db.py
b/airflow-core/src/airflow/utils/db.py
index 748b44acb86..231a74966c8 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -60,7 +60,6 @@ from airflow.utils import helpers
from airflow.utils.db_manager import RunDBManager
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import get_dialect_name
-from airflow.utils.task_instance_session import
get_current_task_instance_session
USE_PSYCOPG3: bool
try:
@@ -1546,7 +1545,7 @@ class LazySelectSequence(Sequence[T]):
_select_asc: Select
_select_desc: Select
- _session: Session = attrs.field(kw_only=True,
factory=get_current_task_instance_session)
+ _session: Session
_len: int | None = attrs.field(init=False, default=None)
@classmethod
@@ -1555,7 +1554,7 @@ class LazySelectSequence(Sequence[T]):
select: Select,
*,
order_by: Sequence[ColumnElement],
- session: Session | None = None,
+ session: Session,
) -> Self:
s1 = select
for col in order_by:
@@ -1563,7 +1562,7 @@ class LazySelectSequence(Sequence[T]):
s2 = select
for col in order_by:
s2 = s2.order_by(col.desc())
- return cls(s1, s2, session=session or
get_current_task_instance_session())
+ return cls(s1, s2, session=session)
@staticmethod
def _rebuild_select(stmt: TextClause) -> Select:
@@ -1603,7 +1602,6 @@ class LazySelectSequence(Sequence[T]):
s1, s2, self._len = state
self._select_asc = self._rebuild_select(text(s1))
self._select_desc = self._rebuild_select(text(s2))
- self._session = get_current_task_instance_session()
def __bool__(self) -> bool:
return check_query_exists(self._select_asc, session=self._session)
diff --git a/airflow-core/src/airflow/utils/task_instance_session.py
b/airflow-core/src/airflow/utils/task_instance_session.py
deleted file mode 100644
index 019a752c773..00000000000
--- a/airflow-core/src/airflow/utils/task_instance_session.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import contextlib
-import logging
-import traceback
-from typing import TYPE_CHECKING
-
-from airflow import settings
-
-if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
-__current_task_instance_session: Session | None = None
-
-log = logging.getLogger(__name__)
-
-
-def get_current_task_instance_session() -> Session:
- global __current_task_instance_session
- if not __current_task_instance_session:
- log.warning("No task session set for this task. Continuing but this
likely causes a resource leak.")
- log.warning("Please report this and stacktrace below to
https://github.com/apache/airflow/issues")
- for filename, line_number, name, line in traceback.extract_stack():
- log.warning('File: "%s", %s , in %s', filename, line_number, name)
- if line:
- log.warning(" %s", line.strip())
- __current_task_instance_session = settings.get_session()()
- return __current_task_instance_session
-
-
-@contextlib.contextmanager
-def set_current_task_instance_session(session: Session):
- global __current_task_instance_session
- if __current_task_instance_session:
- raise RuntimeError(
- "Session already set for this task. "
- "You can only have one 'set_current_task_session' context manager
active at a time."
- )
- __current_task_instance_session = session
- try:
- yield
- finally:
- __current_task_instance_session = None
diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py
b/airflow-core/tests/unit/models/test_renderedtifields.py
index 8083459b6be..f31ed7722c0 100644
--- a/airflow-core/tests/unit/models/test_renderedtifields.py
+++ b/airflow-core/tests/unit/models/test_renderedtifields.py
@@ -41,7 +41,6 @@ from airflow.providers.standard.operators.bash import
BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import task as task_decorator
from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_instance_session import
set_current_task_instance_session
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
@@ -250,32 +249,29 @@ class TestRenderedTaskInstanceFields:
Test that old records are deleted from rendered_task_instance_fields
table
for a given task_id and dag_id.
"""
- with set_current_task_instance_session(session=session):
- with dag_maker("test_delete_old_records") as dag:
- task = BashOperator(task_id="test", bash_command="echo {{ ds
}}")
- rtif_list = []
- for num in range(rtif_num):
- dr = dag_maker.create_dagrun(
- run_id=str(num), logical_date=dag.start_date +
timedelta(days=num)
- )
- ti = dr.task_instances[0]
- ti.task = task
- rtif_list.append(RTIF(ti))
+ with dag_maker("test_delete_old_records") as dag:
+ task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
+ rtif_list = []
+ for num in range(rtif_num):
+ dr = dag_maker.create_dagrun(run_id=str(num),
logical_date=dag.start_date + timedelta(days=num))
+ ti = dr.task_instances[0]
+ ti.task = task
+ rtif_list.append(RTIF(ti))
- session.add_all(rtif_list)
- session.flush()
+ session.add_all(rtif_list)
+ session.flush()
- result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id,
RTIF.task_id == task.task_id).all()
+ result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id,
RTIF.task_id == task.task_id).all()
- for rtif in rtif_list:
- assert rtif in result
+ for rtif in rtif_list:
+ assert rtif in result
- assert rtif_num == len(result)
+ assert rtif_num == len(result)
- with assert_queries_count(expected_query_count):
- RTIF.delete_old_records(task_id=task.task_id,
dag_id=task.dag_id, num_to_keep=num_to_keep)
- result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id,
RTIF.task_id == task.task_id).all()
- assert remaining_rtifs == len(result)
+ with assert_queries_count(expected_query_count):
+ RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id,
num_to_keep=num_to_keep)
+ result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id,
RTIF.task_id == task.task_id).all()
+ assert remaining_rtifs == len(result)
@pytest.mark.parametrize(
("num_runs", "num_to_keep", "remaining_rtifs", "expected_query_count"),
@@ -292,35 +288,32 @@ class TestRenderedTaskInstanceFields:
Test that old records are deleted from rendered_task_instance_fields
table
for a given task_id and dag_id with mapped tasks.
"""
- with set_current_task_instance_session(session=session):
- with dag_maker("test_delete_old_records", session=session,
serialized=True) as dag:
- mapped =
BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
- for num in range(num_runs):
- dr = dag_maker.create_dagrun(
- run_id=f"run_{num}", logical_date=dag.start_date +
timedelta(days=num)
- )
+ with dag_maker("test_delete_old_records", session=session,
serialized=True) as dag:
+ mapped =
BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
+ for num in range(num_runs):
+ dr = dag_maker.create_dagrun(
+ run_id=f"run_{num}", logical_date=dag.start_date +
timedelta(days=num)
+ )
- TaskMap.expand_mapped_task(
- dag.task_dict[mapped.task_id], dr.run_id,
session=dag_maker.session
- )
- session.refresh(dr)
- for ti in dr.task_instances:
- ti.task = mapped
- session.add(RTIF(ti))
- session.flush()
+ TaskMap.expand_mapped_task(dag.task_dict[mapped.task_id],
dr.run_id, session=dag_maker.session)
+ session.refresh(dr)
+ for ti in dr.task_instances:
+ ti.task = mapped
+ session.add(RTIF(ti))
+ session.flush()
- result = session.query(RTIF).filter(RTIF.dag_id ==
dag.dag_id).all()
- assert len(result) == num_runs * 2
+ result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
+ assert len(result) == num_runs * 2
- with assert_queries_count(expected_query_count):
- RTIF.delete_old_records(
- task_id=mapped.task_id, dag_id=dr.dag_id,
num_to_keep=num_to_keep, session=session
- )
- result = session.query(RTIF).filter_by(dag_id=dag.dag_id,
task_id=mapped.task_id).all()
- rtif_num_runs = Counter(rtif.run_id for rtif in result)
- assert len(rtif_num_runs) == remaining_rtifs
- # Check that we have _all_ the data for each row
- assert len(result) == remaining_rtifs * 2
+ with assert_queries_count(expected_query_count):
+ RTIF.delete_old_records(
+ task_id=mapped.task_id, dag_id=dr.dag_id,
num_to_keep=num_to_keep, session=session
+ )
+ result = session.query(RTIF).filter_by(dag_id=dag.dag_id,
task_id=mapped.task_id).all()
+ rtif_num_runs = Counter(rtif.run_id for rtif in result)
+ assert len(rtif_num_runs) == remaining_rtifs
+ # Check that we have _all_ the data for each row
+ assert len(result) == remaining_rtifs * 2
def test_write(self, dag_maker):
"""
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py
b/providers/standard/tests/unit/standard/decorators/test_python.py
index 391d4ae7f4d..cfc5df5ab2a 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -25,7 +25,6 @@ import pytest
from airflow.exceptions import AirflowException, XComNotFound
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
-from airflow.utils.task_instance_session import
set_current_task_instance_session
from tests_common.test_utils.version_compat import (
AIRFLOW_V_3_0_1,
@@ -818,10 +817,49 @@ def
test_mapped_decorator_unmap_merge_op_kwargs(dag_maker, session):
assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
def test_mapped_render_template_fields(dag_maker, session):
@task_decorator
def fn(arg1, arg2): ...
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
+
+ dr = dag_maker.create_dagrun()
+ ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+ ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=1,
+ keys=None,
+ )
+ )
+ session.flush()
+
+ mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id,
session=session)
+ mapped_ti.map_index = 0
+ assert isinstance(mapped_ti.task, MappedOperator)
+
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ assert isinstance(mapped_ti.task, BaseOperator)
+
+ assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
+ assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+
+
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
+def test_mapped_render_template_fields_af2(dag_maker, session):
+ from airflow.utils.task_instance_session import
set_current_task_instance_session
+
+ @task_decorator
+ def fn(arg1, arg2): ...
+
with set_current_task_instance_session(session=session):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
diff --git a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
index aa3be7a3cd8..28b5164c626 100644
--- a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -20,8 +20,6 @@ from unittest.mock import MagicMock, patch
import pytest
-from airflow.utils.task_instance_session import
set_current_task_instance_session
-
pytest.importorskip("weaviate")
from airflow.providers.weaviate.operators.weaviate import (
@@ -87,10 +85,9 @@ class TestWeaviateIngestOperator:
dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
- with set_current_task_instance_session(session=session):
- for ti in tis:
- ti.render_templates()
- assert ti.task.hook_params == {"baz": "biz"}
+ for ti in tis:
+ ti.render_templates()
+ assert ti.task.hook_params == {"baz": "biz"}
class TestWeaviateDocumentIngestOperator:
@@ -147,7 +144,6 @@ class TestWeaviateDocumentIngestOperator:
dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
- with set_current_task_instance_session(session=session):
- for ti in tis:
- ti.render_templates()
- assert ti.task.hook_params == {"baz": "biz"}
+ for ti in tis:
+ ti.render_templates()
+ assert ti.task.hook_params == {"baz": "biz"}