This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7947b72  Rewrite DAG run retrieval in task command (#20737)
7947b72 is described below

commit 7947b72eee61a4596c5d8667f8442d32dcbf3f6d
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Jan 8 17:09:02 2022 +0800

    Rewrite DAG run retrieval in task command (#20737)
---
 airflow/cli/commands/task_command.py | 87 ++++++++++++++++++++++++++----------
 1 file changed, 63 insertions(+), 24 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 9ea4f4d..9458d05 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -16,24 +16,27 @@
 # specific language governing permissions and limitations
 # under the License.
 """Task sub-commands"""
+import datetime
 import importlib
 import json
 import logging
 import os
 import textwrap
-from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
+from contextlib import contextmanager, redirect_stderr, redirect_stdout
 from typing import List, Optional
 
 from pendulum.parsing.exceptions import ParserError
 from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.session import Session
 
 from airflow import settings
 from airflow.cli.simple_table import AirflowConsole
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DagRunNotFound
+from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.models import DagPickle, TaskInstance
+from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.xcom import IN_MEMORY_DAGRUN_ID
@@ -50,46 +53,82 @@ from airflow.utils.cli import (
 from airflow.utils.dates import timezone
 from airflow.utils.log.logging_mixin import StreamLogWriter
 from airflow.utils.net import get_hostname
-from airflow.utils.session import create_session, provide_session
-
-
-def _get_dag_run(dag, exec_date_or_run_id, create_if_necessary, session):
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
+
+
+def _get_dag_run(
+    *,
+    dag: DAG,
+    exec_date_or_run_id: str,
+    create_if_necessary: bool,
+    session: Session,
+) -> DagRun:
+    """Try to retrieve a DAG run from a string representing either a run ID or logical date.
+
+    This checks DAG runs like this:
+
+    1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
+    2. Try to parse the input as a date. If that works, and the resulting
+       date matches a DAG run's logical date, return the run.
+    3. If ``create_if_necessary`` is *False* and the input works for neither of
+       the above, raise ``DagRunNotFound``.
+    4. Try to create a new DAG run. If the input looks like a date, use it as
+       the logical date; otherwise use it as a run ID and set the logical date
+       to the current time.
+    """
     dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
     if dag_run:
         return dag_run
 
-    execution_date = None
-    with suppress(ParserError, TypeError):
-        execution_date = timezone.parse(exec_date_or_run_id)
+    try:
+        execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
+    except (ParserError, TypeError):
+        execution_date = None
 
-    if create_if_necessary and not execution_date:
-        return DagRun(dag_id=dag.dag_id, run_id=exec_date_or_run_id)
     try:
         return (
             session.query(DagRun)
-            .filter(
-                DagRun.dag_id == dag.dag_id,
-                DagRun.execution_date == execution_date,
-            )
+            .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
             .one()
         )
     except NoResultFound:
-        if create_if_necessary:
-            return DagRun(dag.dag_id, run_id=IN_MEMORY_DAGRUN_ID, execution_date=execution_date)
-        raise DagRunNotFound(
-            f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
-        ) from None
+        if not create_if_necessary:
+            raise DagRunNotFound(
+                f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
+            ) from None
+
+    if execution_date is not None:
+        return DagRun(dag.dag_id, run_id=IN_MEMORY_DAGRUN_ID, execution_date=execution_date)
+    return DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=timezone.utcnow())
 
 
 @provide_session
-def _get_ti(task, exec_date_or_run_id, create_if_necessary=False, session=None):
+def _get_ti(
+    task: BaseOperator,
+    exec_date_or_run_id: str,
+    *,
+    create_if_necessary: bool = False,
+    session: Session = NEW_SESSION,
+) -> TaskInstance:
     """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
-    dag_run = _get_dag_run(task.dag, exec_date_or_run_id, create_if_necessary, session)
+    dag_run = _get_dag_run(
+        dag=task.dag,
+        exec_date_or_run_id=exec_date_or_run_id,
+        create_if_necessary=create_if_necessary,
+        session=session,
+    )
 
-    ti = dag_run.get_task_instance(task.task_id)
-    if not ti and create_if_necessary:
+    ti_or_none = dag_run.get_task_instance(task.task_id)
+    if ti_or_none is None:
+        if not create_if_necessary:
+            raise TaskInstanceNotFound(
+                f"TaskInstance for {task.dag.dag_id}, {task.task_id} with "
+                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
+            )
         ti = TaskInstance(task, run_id=dag_run.run_id)
         ti.dag_run = dag_run
+    else:
+        ti = ti_or_none
     ti.refresh_from_task(task)
     return ti
 

Reply via email to