This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 279b45625b Update render filename to use internal API (#38558)
279b45625b is described below
commit 279b45625b99d6522ef97611f89c8353d05ca3a6
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 08:47:05 2024 -0700
Update render filename to use internal API (#38558)
For AIP-44. Previously we would always create a session every time. I
refactored the portion that needs the session into the sub-function
_render_filename_db_access.
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 2 +
airflow/utils/log/file_task_handler.py | 58 +++++++++++++++-------
tests/task/task_runner/test_task_runner.py | 4 +-
3 files changed, 44 insertions(+), 20 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 11bf5f9359..5074504b8d 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -45,6 +45,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
from airflow.secrets.metastore import MetastoreBackend
+ from airflow.utils.log.file_task_handler import FileTaskHandler
functions: list[Callable] = [
DagFileProcessor.update_import_errors,
@@ -55,6 +56,7 @@ def _initialize_map() -> dict[str, Callable]:
DagModel.get_current,
DagFileProcessorManager.clear_nonexistent_import_errors,
DagWarning.purge_inactive_dag_warnings,
+ FileTaskHandler._render_filename_db_access,
Job._add_to_db,
Job._fetch_from_db,
Job._kill,
diff --git a/airflow/utils/log/file_task_handler.py
b/airflow/utils/log/file_task_handler.py
index a5a2da6062..95d6849b20 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -32,6 +32,7 @@ from urllib.parse import urljoin
import pendulum
+from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.executors.executor_loader import ExecutorLoader
@@ -39,11 +40,16 @@ from airflow.utils.context import Context
from airflow.utils.helpers import parse_template_string,
render_template_to_string
from airflow.utils.log.logging_mixin import SetContextPropagate
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
-from airflow.utils.session import create_session
+from airflow.utils.session import provide_session
from airflow.utils.state import State, TaskInstanceState
if TYPE_CHECKING:
+ from pendulum import DateTime
+
+ from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ from airflow.serialization.pydantic.dag_run import DagRunPydantic
+ from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
logger = logging.getLogger(__name__)
@@ -134,14 +140,14 @@ def _interleave_logs(*logs):
last = v
-def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic,
session) -> TaskInstance:
"""Given TI | TIKey, return a TI object.
Will raise exception if no TI is found in the database.
"""
- from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+ from airflow.models.taskinstance import TaskInstance
- if not isinstance(ti, TaskInstanceKey):
+ if isinstance(ti, TaskInstance):
return ti
val = (
session.query(TaskInstance)
@@ -255,22 +261,33 @@ class FileTaskHandler(logging.Handler):
if self.handler:
self.handler.close()
- def _render_filename(self, ti: TaskInstance | TaskInstanceKey, try_number:
int) -> str:
+ @staticmethod
+ @internal_api_call
+ @provide_session
+ def _render_filename_db_access(
+ *, ti, try_number: int, session=None
+ ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic,
str | None, str | None]:
+ ti = _ensure_ti(ti, session)
+ dag_run = ti.get_dagrun(session=session)
+ template = dag_run.get_log_template(session=session).filename
+ str_tpl, jinja_tpl = parse_template_string(template)
+ filename = None
+ if jinja_tpl:
+ if getattr(ti, "task", None) is not None:
+ context = ti.get_template_context(session=session)
+ else:
+ context = Context(ti=ti, ts=dag_run.logical_date.isoformat())
+ context["try_number"] = try_number
+ filename = render_template_to_string(jinja_tpl, context)
+ return dag_run, ti, str_tpl, filename
+
+ def _render_filename(
+ self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic,
try_number: int
+ ) -> str:
"""Return the worker log filename."""
- with create_session() as session:
- ti = _ensure_ti(ti, session)
- dag_run = ti.get_dagrun(session=session)
- template = dag_run.get_log_template(session=session).filename
- str_tpl, jinja_tpl = parse_template_string(template)
-
- if jinja_tpl:
- if getattr(ti, "task", None) is not None:
- context = ti.get_template_context(session=session)
- else:
- context = Context(ti=ti,
ts=dag_run.logical_date.isoformat())
- context["try_number"] = try_number
- return render_template_to_string(jinja_tpl, context)
-
+ dag_run, ti, str_tpl, filename =
self._render_filename_db_access(ti=ti, try_number=try_number)
+ if filename:
+ return filename
if str_tpl:
if ti.task is not None and ti.task.dag is not None:
dag = ti.task.dag
@@ -278,6 +295,9 @@ class FileTaskHandler(logging.Handler):
else:
from airflow.timetables.base import DataInterval
+ if TYPE_CHECKING:
+ assert isinstance(dag_run.data_interval_start, DateTime)
+ assert isinstance(dag_run.data_interval_end, DateTime)
data_interval = DataInterval(dag_run.data_interval_start,
dag_run.data_interval_end)
if data_interval[0]:
data_interval_start = data_interval[0].isoformat()
diff --git a/tests/task/task_runner/test_task_runner.py
b/tests/task/task_runner/test_task_runner.py
index 2d5794aff7..6214930e36 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -35,12 +35,14 @@ class TestGetTaskRunner:
def test_should_have_valid_imports(self, import_path):
assert import_string(import_path) is not None
+ @mock.patch("airflow.utils.log.file_task_handler._ensure_ti")
@mock.patch("airflow.task.task_runner.base_task_runner.subprocess")
@mock.patch("airflow.task.task_runner._TASK_RUNNER_NAME",
"StandardTaskRunner")
- def test_should_support_core_task_runner(self, mock_subprocess):
+ def test_should_support_core_task_runner(self, mock_subprocess,
mock_ensure_ti):
ti = mock.MagicMock(map_index=-1, run_as_user=None)
ti.get_template_context.return_value = {"ti": ti}
ti.get_dagrun.return_value.get_log_template.return_value.filename =
"blah"
+ mock_ensure_ti.return_value = ti
Job = mock.MagicMock(task_instance=ti)
Job.job_type = None
job_runner = LocalTaskJobRunner(job=Job, task_instance=ti)