sunank200 commented on code in PR #42404:
URL: https://github.com/apache/airflow/pull/42404#discussion_r1835986921


##########
airflow/models/dag.py:
##########
@@ -1337,42 +1322,35 @@ def set_task_group_state(
         """
         from airflow.api.common.mark_tasks import set_state
 
-        if not exactly_one(execution_date, run_id):
-            raise ValueError("Exactly one of execution_date or run_id must be provided")
-
         tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
         task_ids: list[str] = []
 
-        if execution_date is None:
-            dag_run = session.scalars(
-                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
-            ).one()  # Raises an error if not found
-            resolve_execution_date = dag_run.execution_date
-        else:
-            resolve_execution_date = execution_date
-
-        end_date = resolve_execution_date if not future else None
-        start_date = resolve_execution_date if not past else None
-
         task_group_dict = self.task_group.get_task_group_dict()
         task_group = task_group_dict.get(group_id)
         if task_group is None:
             raise ValueError("TaskGroup {group_id} could not be found")
         tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
         task_ids = [task.task_id for task in task_group.iter_tasks()]
         dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
-        if start_date is None and end_date is None:
-            dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date)
-        else:
-            if start_date is not None:
-                dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date)
-            if end_date is not None:
-                dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date)
+
+        @cache
+        def get_logical_date() -> datetime:
+            stmt = select(DagRun.logical_date).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
+            return session.scalars(stmt).one()  # Raises an error if not found
+
+        end_date = get_logical_date() if not future else None
+        start_date = get_logical_date() if not past else None

Review Comment:
   Changed it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to