This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit eceae90d28b8490499640763fa3ec3c0cb01c989
Author: Jarek Potiuk <[email protected]>
AuthorDate: Mon Jun 29 16:43:45 2020 +0200

    [AIRFLOW-6957] Make retrieving Paused Dag ids a separate method
    
       (cherry picked from commit a887e0a1a02e12e00687ff123220de095e560647)
---
 airflow/jobs/scheduler_job.py |  2 +-
 airflow/models/dag.py         | 20 ++++++++++++++++++++
 tests/models/test_dag.py      | 15 +++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 0665779..5b06be8 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -1266,7 +1266,7 @@ class SchedulerJob(BaseJob):
         :param dagbag: a collection of DAGs to process
         :type dagbag: airflow.models.DagBag
         :param dags: the DAGs from the DagBag to process
-        :type dags: airflow.models.DAG
+        :type dags: list[airflow.models.DAG]
         :param tis_out: A list to add generated TaskInstance objects
         :type tis_out: list[TaskInstance]
         :rtype: None
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 3f46bf6..7759cb3 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1785,6 +1785,26 @@ class DagModel(Base):
         return get_last_dagrun(self.dag_id, session=session,
                                include_externally_triggered=include_externally_triggered)
 
+    @staticmethod
+    @provide_session
+    def get_paused_dag_ids(dag_ids, session=None):
+        """
+        Given a list of dag_ids, return the subset whose DagModel is paused.
+
+        :param dag_ids: List of Dag ids to check
+        :param session: ORM Session (supplied by ``provide_session`` if omitted)
+        :return: set of the given dag_ids whose ``is_paused`` column is True
+        """
+        paused_rows = (
+            session.query(DagModel.dag_id)
+            .filter(DagModel.is_paused.is_(True))
+            .filter(DagModel.dag_id.in_(dag_ids))
+            .all()
+        )
+
+        # Each row holds only the selected dag_id column; unpack the 1-tuples into a set.
+        return {dag_id for (dag_id,) in paused_rows}
+
     @property
     def safe_dag_id(self):
         return self.dag_id.replace('.', '__dot__')
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index cdbe1ee..5d9d05d 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -42,6 +42,7 @@ from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.subdag_operator import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.dag_processing import list_py_file_paths
+from airflow.utils.db import create_session
 from airflow.utils.state import State
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
@@ -941,3 +942,17 @@ class DagTest(unittest.TestCase):
             assert issubclass(PendingDeprecationWarning, warning.category)
 
             self.assertEqual(dag.task_dict, {t1.task_id: t1})
+
+    def test_get_paused_dag_ids(self):
+        dag_id = "test_get_paused_dag_ids"
+        dag = DAG(dag_id, is_paused_upon_creation=True)
+        dag.sync_to_db()
+        # Clean up in ``finally`` so a failed assertion cannot leak the row into other tests.
+        try:
+            self.assertIsNotNone(DagModel.get_dagmodel(dag_id))
+            self.assertEqual(DagModel.get_paused_dag_ids([dag_id]), {dag_id})
+        finally:
+            with create_session() as session:
+                session.query(DagModel).filter(
+                    DagModel.dag_id == dag_id).delete(
+                    synchronize_session=False)

Reply via email to