uranusjr commented on a change in pull request #17719:
URL: https://github.com/apache/airflow/pull/17719#discussion_r692030362



##########
File path: airflow/jobs/backfill_job.py
##########
@@ -645,16 +645,14 @@ def tabulate_ti_keys_set(set_ti_keys: 
Set[TaskInstanceKey]) -> str:
             # Sorting by execution date first
             sorted_ti_keys = sorted(
                 set_ti_keys,
-                key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, 
ti_key.task_id, ti_key.try_number),
+                key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, 
ti_key.task_id, ti_key.try_number),

Review comment:
       Since we have an FK constraint and a relationship, can we do a join and 
keep `TaskInstanceKey.execution_date`, backed by the DagRun’s 
`execution_date`, instead? (i.e. don’t change this function, but change how 
`TaskInstanceKey` is created.) The current change would break ordering if a 
DagRun has a custom run ID, because run IDs do not necessarily sort in execution-date order.

##########
File path: airflow/jobs/backfill_job.py
##########
@@ -645,16 +645,14 @@ def tabulate_ti_keys_set(set_ti_keys: 
Set[TaskInstanceKey]) -> str:
             # Sorting by execution date first
             sorted_ti_keys = sorted(
                 set_ti_keys,
-                key=lambda ti_key: (ti_key.execution_date, ti_key.dag_id, 
ti_key.task_id, ti_key.try_number),
+                key=lambda ti_key: (ti_key.run_id, ti_key.dag_id, 
ti_key.task_id, ti_key.try_number),
             )
             return tabulate(sorted_ti_keys, headers=["DAG ID", "Task ID", 
"Execution date", "Try number"])
 
         def tabulate_tis_set(set_tis: Set[TaskInstance]) -> str:
             # Sorting by execution date first
-            sorted_tis = sorted(
-                set_tis, key=lambda ti: (ti.execution_date, ti.dag_id, 
ti.task_id, ti.try_number)
-            )
-            tis_values = ((ti.dag_id, ti.task_id, ti.execution_date, 
ti.try_number) for ti in sorted_tis)
+            sorted_tis = sorted(set_tis, key=lambda ti: (ti.run_id, ti.dag_id, 
ti.task_id, ti.try_number))
+            tis_values = ((ti.dag_id, ti.task_id, ti.run_id, ti.try_number) 
for ti in sorted_tis)

Review comment:
       Similar to above, but here we need to join against DagRun when these 
`TaskInstance` objects are fetched, so that accessing `ti.dag_run.execution_date` is cheap (no extra query per TI).

##########
File path: airflow/models/skipmixin.py
##########
@@ -91,7 +79,28 @@ def skip(
         if not tasks:
             return
 
-        self._set_state_to_skipped(dag_run, execution_date, tasks, session)
+        if execution_date and not dag_run:
+            from airflow.models.dagrun import DagRun
+
+            warnings.warn(
+                "Passing an execution_date to `skip()` is deprecated in favour 
of passing a dag_run",
+                DeprecationWarning,
+                stacklevel=2,
+            )

Review comment:
       I think we also need a warning when `dag_run` and `execution_date` are 
*both* passed, saying that `execution_date` is now ignored and should be removed. Or 
maybe we can check that the value matches `dag_run.execution_date` and raise a 
`ValueError` if it doesn’t. Or both.

##########
File path: airflow/cli/commands/task_command.py
##########
@@ -51,15 +51,17 @@
 from airflow.utils.session import create_session, provide_session
 
 
-def _get_ti(task, exec_date_or_run_id):
+@provide_session
+def _get_ti(task, exec_date_or_run_id, session):
     """Get the task instance through DagRun.run_id, if that fails, get the TI 
the old way"""
-    dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id)
+    dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
     if not dag_run:
         try:
             execution_date = timezone.parse(exec_date_or_run_id)
-            ti = TaskInstance(task, execution_date)
-            ti.refresh_from_db()
-            return ti
+            dag_run.session.query(DagRun).filter(
+                DagRun.dag_id == task.dag_id,
+                DagRun.execution_date == execution_date,
+            ).one()

Review comment:
       ```suggestion
               session.query(DagRun).filter(
                   DagRun.dag_id == task.dag_id,
                   DagRun.execution_date == execution_date,
               ).one()
   ```
   
   ?

##########
File path: airflow/models/skipmixin.py
##########
@@ -37,36 +38,23 @@
 class SkipMixin(LoggingMixin):
     """A Mixin to skip Tasks Instances"""
 
-    def _set_state_to_skipped(self, dag_run, execution_date, tasks, session):
+    def _set_state_to_skipped(self, dag_run, tasks, session):

Review comment:
       ```suggestion
       def _set_state_to_skipped(self, dag_run: DagRun, tasks: 
Iterable[BaseOperator], session: Session):
   ```
   
   This might help catch lingering call sites that still don’t pass in a DagRun… (We also 
need to annotate `skip` below.)

##########
File path: airflow/models/dagrun.py
##########
@@ -675,17 +671,15 @@ def verify_integrity(self, session: Session = None):
 
             if task.task_id not in task_ids:
                 Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
-                ti = TI(task, self.execution_date)
+                ti = TI(task, execution_date=None, run_id=self.run_id)
                 task_instance_mutation_hook(ti)
                 session.add(ti)
 
         try:
             session.flush()
         except IntegrityError as err:
             self.log.info(str(err))
-            self.log.info(
-                'Hit IntegrityError while creating the TIs for ' 
f'{dag.dag_id} - {self.execution_date}.'
-            )
+            self.log.info('Hit IntegrityError while creating the TIs for ' 
f'{dag.dag_id} - {self.run_id}.')

Review comment:
       ```suggestion
               self.log.info('Hit IntegrityError while creating the TIs for %s 
- %s.', dag.dag_id, self.run_id)
   ```
   
   ?

##########
File path: airflow/cli/commands/task_command.py
##########
@@ -51,15 +51,17 @@
 from airflow.utils.session import create_session, provide_session
 
 
-def _get_ti(task, exec_date_or_run_id):
+@provide_session
+def _get_ti(task, exec_date_or_run_id, session):
     """Get the task instance through DagRun.run_id, if that fails, get the TI 
the old way"""
-    dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id)
+    dag_run = task.dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
     if not dag_run:
         try:
             execution_date = timezone.parse(exec_date_or_run_id)
-            ti = TaskInstance(task, execution_date)
-            ti.refresh_from_db()
-            return ti
+            dag_run.session.query(DagRun).filter(
+                DagRun.dag_id == task.dag_id,
+                DagRun.execution_date == execution_date,
+            ).one()
         except (ParserError, TypeError):

Review comment:
       Can these exceptions still be raised here, or should we catch `NoResultFound` instead?

##########
File path: airflow/models/baseoperator.py
##########
@@ -1257,24 +1257,46 @@ def get_flat_relatives(self, upstream: bool = False):
         dag: DAG = self._dag
         return list(map(lambda task_id: dag.task_dict[task_id], 
self.get_flat_relative_ids(upstream)))
 
+    @provide_session
     def run(
         self,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         ignore_first_depends_on_past: bool = True,
         ignore_ti_state: bool = False,
         mark_success: bool = False,
+        session: Session = None,
     ) -> None:
         """Run a set of task instances for a date range."""
         start_date = start_date or self.start_date
         end_date = end_date or self.end_date or timezone.utcnow()
 
         for info in self.dag.iter_dagrun_infos_between(start_date, end_date, 
align=False):
             ignore_depends_on_past = info.logical_date == start_date and 
ignore_first_depends_on_past
-            TaskInstance(self, info.logical_date).run(
+            try:
+                ti = TaskInstance(self, info.logical_date)
+            except DagRunNotFound:

Review comment:
       It’s pretty surprising for a class’s `__init__` to raise an exception. I 
understand it’s unavoidable in general when `logical_date` is passed in, but 
perhaps here we can avoid it by fetching the `DagRun` instance beforehand 
instead of passing in `info.logical_date`? We can also bulk-read the `run_id`s 
before the loop instead of reading them one by one.
   
   Something like
   
   ```python
   logical_dates = [
       info.logical_date
       for info in self.dag.iter_dagrun_infos_between(start_date, end_date, 
align=False)
   ]
   dagruns = {
       run.execution_date: run
       for run in 
session.query(DagRun).filter(DagRun.execution_date.in_(logical_dates))
   }
   for date in logical_dates:
       try:
           run = dagruns[date]
       except KeyError:
           # Create a "fake" run…
       ti = TaskInstance(self, run_id=run.run_id)
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to