This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch v2-2-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 901901a4d19e3fbe33b58881bf7e0b09e4982fed Author: Gulshan Gill <[email protected]> AuthorDate: Thu Nov 4 03:09:41 2021 +0800 Use ``execution_date`` to check for existing ``DagRun`` for ``TriggerDagRunOperator`` (#18968) A small suggestion to change `DagRun.find` in `trigger_dag` to use `execution_date` as a parameter rather than `run_id`. I feel it would be better to use this rather than `run_id` as a parameter, since checking only by `run_id` misses a scheduled run that ran at the same `execution_date`, which then causes the error below when a new run with the same `execution_date` is created: ``` sqlalchemy.exc.IntegrityError: (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "dag_run_dag_id_execution_date_key" ``` There is a constraint in `dag_run` called `dag_run_dag_id_execution_date_key` which can be found [here](https://github.com/apache/airflow/blob/c4f5233cd10ae03ee69fba861c8a6fa64e1f8a71/airflow/models/dagrun.py#L103). (cherry picked from commit e54ee6e0d38ca469be6ba686e32ce7a3a34d03ca) --- airflow/api/common/experimental/trigger_dag.py | 6 ++- airflow/models/dagrun.py | 65 +++++++++++++++++------ tests/api/common/experimental/test_trigger_dag.py | 6 +-- tests/models/test_dagrun.py | 23 ++++++++ 4 files changed, 78 insertions(+), 22 deletions(-) diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index 2e64f86..38a873c 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -68,10 +68,12 @@ def _trigger_dag( ) run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date) - dag_run = DagRun.find(dag_id=dag_id, run_id=run_id) + dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id) if dag_run: - raise DagRunAlreadyExists(f"Run id {run_id} already exists for dag id {dag_id}") + raise DagRunAlreadyExists( + f"A Dag Run already exists for dag id 
{dag_id} at {execution_date} with run id {run_id}" + ) run_conf = None if conf: diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 800720c..8d2ab2a 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -285,12 +285,13 @@ class DagRun(Base, LoggingMixin): query.limit(max_number), of=cls, session=session, **skip_locked(session=session) ) - @staticmethod + @classmethod @provide_session def find( + cls, dag_id: Optional[Union[str, List[str]]] = None, run_id: Optional[str] = None, - execution_date: Optional[datetime] = None, + execution_date: Optional[Union[datetime, List[datetime]]] = None, state: Optional[DagRunState] = None, external_trigger: Optional[bool] = None, no_backfills: bool = False, @@ -324,35 +325,65 @@ class DagRun(Base, LoggingMixin): :param execution_end_date: dag run that was executed until this date :type execution_end_date: datetime.datetime """ - DR = DagRun - - qry = session.query(DR) + qry = session.query(cls) dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id if dag_ids: - qry = qry.filter(DR.dag_id.in_(dag_ids)) + qry = qry.filter(cls.dag_id.in_(dag_ids)) if run_id: - qry = qry.filter(DR.run_id == run_id) + qry = qry.filter(cls.run_id == run_id) if execution_date: if isinstance(execution_date, list): - qry = qry.filter(DR.execution_date.in_(execution_date)) + qry = qry.filter(cls.execution_date.in_(execution_date)) else: - qry = qry.filter(DR.execution_date == execution_date) + qry = qry.filter(cls.execution_date == execution_date) if execution_start_date and execution_end_date: - qry = qry.filter(DR.execution_date.between(execution_start_date, execution_end_date)) + qry = qry.filter(cls.execution_date.between(execution_start_date, execution_end_date)) elif execution_start_date: - qry = qry.filter(DR.execution_date >= execution_start_date) + qry = qry.filter(cls.execution_date >= execution_start_date) elif execution_end_date: - qry = qry.filter(DR.execution_date <= execution_end_date) + qry = 
qry.filter(cls.execution_date <= execution_end_date) if state: - qry = qry.filter(DR.state == state) + qry = qry.filter(cls.state == state) if external_trigger is not None: - qry = qry.filter(DR.external_trigger == external_trigger) + qry = qry.filter(cls.external_trigger == external_trigger) if run_type: - qry = qry.filter(DR.run_type == run_type) + qry = qry.filter(cls.run_type == run_type) if no_backfills: - qry = qry.filter(DR.run_type != DagRunType.BACKFILL_JOB) + qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB) + + return qry.order_by(cls.execution_date).all() + + @classmethod + @provide_session + def find_duplicate( + cls, + dag_id: str, + run_id: str, + execution_date: datetime, + session: Session = None, + ) -> Optional['DagRun']: + """ + Return an existing run for the DAG with a specific run_id or execution_date. - return qry.order_by(DR.execution_date).all() + *None* is returned if no such DAG run is found. + + :param dag_id: the dag_id to find duplicates for + :type dag_id: str + :param run_id: defines the run id for this dag run + :type run_id: str + :param execution_date: the execution date + :type execution_date: datetime.datetime + :param session: database session + :type session: sqlalchemy.orm.session.Session + """ + return ( + session.query(cls) + .filter( + cls.dag_id == dag_id, + or_(cls.run_id == run_id, cls.execution_date == execution_date), + ) + .one_or_none() + ) @staticmethod def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str: diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py index cbca935..2f16446 100644 --- a/tests/api/common/experimental/test_trigger_dag.py +++ b/tests/api/common/experimental/test_trigger_dag.py @@ -49,7 +49,7 @@ class TestTriggerDag(unittest.TestCase): dag = DAG(dag_id) dag_bag_mock.dags = [dag_id] dag_bag_mock.get_dag.return_value = dag - dag_run_mock.find.return_value = DagRun() + 
dag_run_mock.find_duplicate.return_value = DagRun() with pytest.raises(AirflowException): _trigger_dag(dag_id, dag_bag_mock) @@ -60,7 +60,7 @@ class TestTriggerDag(unittest.TestCase): dag_id = "trigger_dag" dag_bag_mock.dags = [dag_id] dag_bag_mock.get_dag.return_value = dag_mock - dag_run_mock.find.return_value = None + dag_run_mock.find_duplicate.return_value = None dag1 = mock.MagicMock(subdags=[]) dag2 = mock.MagicMock(subdags=[]) dag_mock.subdags = [dag1, dag2] @@ -76,7 +76,7 @@ class TestTriggerDag(unittest.TestCase): dag_id = "trigger_dag" dag_bag_mock.dags = [dag_id] dag_bag_mock.get_dag.return_value = dag_mock - dag_run_mock.find.return_value = None + dag_run_mock.find_duplicate.return_value = None dag1 = mock.MagicMock(subdags=[]) dag2 = mock.MagicMock(subdags=[dag1]) dag_mock.subdags = [dag1, dag2] diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index c4ef287..00799be 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -142,6 +142,29 @@ class TestDagRun(unittest.TestCase): assert 0 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)) assert 1 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)) + def test_dagrun_find_duplicate(self): + session = settings.Session() + now = timezone.utcnow() + + dag_id = "test_dagrun_find_duplicate" + dag_run = models.DagRun( + dag_id=dag_id, + run_id=dag_id, + run_type=DagRunType.MANUAL, + execution_date=now, + start_date=now, + state=State.RUNNING, + external_trigger=True, + ) + session.add(dag_run) + + session.commit() + + assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, execution_date=now) is not None + assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, execution_date=None) is not None + assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=None, execution_date=now) is not None + assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=None, execution_date=None) is None + def 
test_dagrun_success_when_all_skipped(self): """ Tests that a DAG run succeeds when all tasks are skipped
