This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 279b45625b Update render filename to use internal API (#38558)
279b45625b is described below

commit 279b45625b99d6522ef97611f89c8353d05ca3a6
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 08:47:05 2024 -0700

    Update render filename to use internal API (#38558)
    
    For AIP-44. Previously we would always create a session every time. I
refactored the portion that needs the session into the sub-function
_render_filename_db_access.
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  2 +
 airflow/utils/log/file_task_handler.py             | 58 +++++++++++++++-------
 tests/task/task_runner/test_task_runner.py         |  4 +-
 3 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py 
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 11bf5f9359..5074504b8d 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -45,6 +45,7 @@ def _initialize_map() -> dict[str, Callable]:
     from airflow.models.serialized_dag import SerializedDagModel
     from airflow.models.taskinstance import TaskInstance
     from airflow.secrets.metastore import MetastoreBackend
+    from airflow.utils.log.file_task_handler import FileTaskHandler
 
     functions: list[Callable] = [
         DagFileProcessor.update_import_errors,
@@ -55,6 +56,7 @@ def _initialize_map() -> dict[str, Callable]:
         DagModel.get_current,
         DagFileProcessorManager.clear_nonexistent_import_errors,
         DagWarning.purge_inactive_dag_warnings,
+        FileTaskHandler._render_filename_db_access,
         Job._add_to_db,
         Job._fetch_from_db,
         Job._kill,
diff --git a/airflow/utils/log/file_task_handler.py 
b/airflow/utils/log/file_task_handler.py
index a5a2da6062..95d6849b20 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -32,6 +32,7 @@ from urllib.parse import urljoin
 
 import pendulum
 
+from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.executors.executor_loader import ExecutorLoader
@@ -39,11 +40,16 @@ from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, 
render_template_to_string
 from airflow.utils.log.logging_mixin import SetContextPropagate
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
-from airflow.utils.session import create_session
+from airflow.utils.session import provide_session
 from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
+    from pendulum import DateTime
+
+    from airflow.models import DagRun
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.serialization.pydantic.dag_run import DagRunPydantic
+    from airflow.serialization.pydantic.taskinstance import 
TaskInstancePydantic
 
 logger = logging.getLogger(__name__)
 
@@ -134,14 +140,14 @@ def _interleave_logs(*logs):
         last = v
 
 
-def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, 
session) -> TaskInstance:
     """Given TI | TIKey, return a TI object.
 
     Will raise exception if no TI is found in the database.
     """
-    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.models.taskinstance import TaskInstance
 
-    if not isinstance(ti, TaskInstanceKey):
+    if isinstance(ti, TaskInstance):
         return ti
     val = (
         session.query(TaskInstance)
@@ -255,22 +261,33 @@ class FileTaskHandler(logging.Handler):
         if self.handler:
             self.handler.close()
 
-    def _render_filename(self, ti: TaskInstance | TaskInstanceKey, try_number: 
int) -> str:
+    @staticmethod
+    @internal_api_call
+    @provide_session
+    def _render_filename_db_access(
+        *, ti, try_number: int, session=None
+    ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, 
str | None, str | None]:
+        ti = _ensure_ti(ti, session)
+        dag_run = ti.get_dagrun(session=session)
+        template = dag_run.get_log_template(session=session).filename
+        str_tpl, jinja_tpl = parse_template_string(template)
+        filename = None
+        if jinja_tpl:
+            if getattr(ti, "task", None) is not None:
+                context = ti.get_template_context(session=session)
+            else:
+                context = Context(ti=ti, ts=dag_run.logical_date.isoformat())
+            context["try_number"] = try_number
+            filename = render_template_to_string(jinja_tpl, context)
+        return dag_run, ti, str_tpl, filename
+
+    def _render_filename(
+        self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, 
try_number: int
+    ) -> str:
         """Return the worker log filename."""
-        with create_session() as session:
-            ti = _ensure_ti(ti, session)
-            dag_run = ti.get_dagrun(session=session)
-            template = dag_run.get_log_template(session=session).filename
-            str_tpl, jinja_tpl = parse_template_string(template)
-
-            if jinja_tpl:
-                if getattr(ti, "task", None) is not None:
-                    context = ti.get_template_context(session=session)
-                else:
-                    context = Context(ti=ti, 
ts=dag_run.logical_date.isoformat())
-                context["try_number"] = try_number
-                return render_template_to_string(jinja_tpl, context)
-
+        dag_run, ti, str_tpl, filename = 
self._render_filename_db_access(ti=ti, try_number=try_number)
+        if filename:
+            return filename
         if str_tpl:
             if ti.task is not None and ti.task.dag is not None:
                 dag = ti.task.dag
@@ -278,6 +295,9 @@ class FileTaskHandler(logging.Handler):
             else:
                 from airflow.timetables.base import DataInterval
 
+                if TYPE_CHECKING:
+                    assert isinstance(dag_run.data_interval_start, DateTime)
+                    assert isinstance(dag_run.data_interval_end, DateTime)
                 data_interval = DataInterval(dag_run.data_interval_start, 
dag_run.data_interval_end)
             if data_interval[0]:
                 data_interval_start = data_interval[0].isoformat()
diff --git a/tests/task/task_runner/test_task_runner.py 
b/tests/task/task_runner/test_task_runner.py
index 2d5794aff7..6214930e36 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -35,12 +35,14 @@ class TestGetTaskRunner:
     def test_should_have_valid_imports(self, import_path):
         assert import_string(import_path) is not None
 
+    @mock.patch("airflow.utils.log.file_task_handler._ensure_ti")
     @mock.patch("airflow.task.task_runner.base_task_runner.subprocess")
     @mock.patch("airflow.task.task_runner._TASK_RUNNER_NAME", 
"StandardTaskRunner")
-    def test_should_support_core_task_runner(self, mock_subprocess):
+    def test_should_support_core_task_runner(self, mock_subprocess, 
mock_ensure_ti):
         ti = mock.MagicMock(map_index=-1, run_as_user=None)
         ti.get_template_context.return_value = {"ti": ti}
         ti.get_dagrun.return_value.get_log_template.return_value.filename = 
"blah"
+        mock_ensure_ti.return_value = ti
         Job = mock.MagicMock(task_instance=ti)
         Job.job_type = None
         job_runner = LocalTaskJobRunner(job=Job, task_instance=ti)

Reply via email to