Repository: incubator-airflow
Updated Branches:
  refs/heads/master dc38b2f46 -> 2090011bb


[AIRFLOW-2355] Airflow trigger dag parameters in subdag

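Parameters passed through `airflow trigger_dag -c '{"text": "blah"}'`
can now also be accessed through `{{ dag_run.conf.text }}` inside
subdags: trigger_dag creates a DagRun with the same run_id and conf
for the triggered DAG and for each of its subdags.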

Closes #3460 from milton0825/trigger-dag-subdag


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/2090011b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/2090011b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/2090011b

Branch: refs/heads/master
Commit: 2090011bb7b2619bb6981904f35c2d30b13de64f
Parents: dc38b2f
Author: milton0825 <[email protected]>
Authored: Sun Jun 17 19:19:05 2018 +0100
Committer: Kaxil Naik <[email protected]>
Committed: Sun Jun 17 19:19:05 2018 +0100

----------------------------------------------------------------------
 airflow/api/common/experimental/trigger_dag.py  | 58 +++++++++---
 .../common/experimental/trigger_dag_tests.py    | 93 ++++++++++++++++++++
 2 files changed, 139 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/2090011b/airflow/api/common/experimental/trigger_dag.py
----------------------------------------------------------------------
diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py
index 67c43e1..fd9b51f 100644
--- a/airflow/api/common/experimental/trigger_dag.py
+++ b/airflow/api/common/experimental/trigger_dag.py
@@ -25,14 +25,19 @@ from airflow.utils import timezone
 from airflow.utils.state import State
 
 
-def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None,
-                replace_microseconds=True):
-    dagbag = DagBag()
-
-    if dag_id not in dagbag.dags:
+def _trigger_dag(
+        dag_id,
+        dag_bag,
+        dag_run,
+        run_id,
+        conf,
+        execution_date,
+        replace_microseconds,
+):
+    if dag_id not in dag_bag.dags:
         raise AirflowException("Dag id {} not found".format(dag_id))
 
-    dag = dagbag.get_dag(dag_id)
+    dag = dag_bag.get_dag(dag_id)
 
     if not execution_date:
         execution_date = timezone.utcnow()
@@ -45,7 +50,7 @@ def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None,
     if not run_id:
         run_id = "manual__{0}".format(execution_date.isoformat())
 
-    dr = DagRun.find(dag_id=dag_id, run_id=run_id)
+    dr = dag_run.find(dag_id=dag_id, run_id=run_id)
     if dr:
         raise AirflowException("Run id {} already exists for dag id {}".format(
             run_id,
@@ -56,12 +61,41 @@ def trigger_dag(dag_id, run_id=None, conf=None, execution_date=None,
     if conf:
         run_conf = json.loads(conf)
 
-    trigger = dag.create_dagrun(
+    triggers = list()
+    dags_to_trigger = list()
+    dags_to_trigger.append(dag)
+    while dags_to_trigger:
+        dag = dags_to_trigger.pop()
+        trigger = dag.create_dagrun(
+            run_id=run_id,
+            execution_date=execution_date,
+            state=State.RUNNING,
+            conf=run_conf,
+            external_trigger=True,
+        )
+        triggers.append(trigger)
+        if dag.subdags:
+            dags_to_trigger.extend(dag.subdags)
+    return triggers
+
+
+def trigger_dag(
+        dag_id,
+        run_id=None,
+        conf=None,
+        execution_date=None,
+        replace_microseconds=True,
+):
+    dagbag = DagBag()
+    dag_run = DagRun()
+    triggers = _trigger_dag(
+        dag_id=dag_id,
+        dag_run=dag_run,
+        dag_bag=dagbag,
         run_id=run_id,
+        conf=conf,
         execution_date=execution_date,
-        state=State.RUNNING,
-        conf=run_conf,
-        external_trigger=True
+        replace_microseconds=replace_microseconds,
     )
 
-    return trigger
+    return triggers[0] if triggers else None

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/2090011b/tests/api/common/experimental/trigger_dag_tests.py
----------------------------------------------------------------------
diff --git a/tests/api/common/experimental/trigger_dag_tests.py b/tests/api/common/experimental/trigger_dag_tests.py
new file mode 100644
index 0000000..d635484
--- /dev/null
+++ b/tests/api/common/experimental/trigger_dag_tests.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mock
+import unittest
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun
+from airflow.api.common.experimental.trigger_dag import _trigger_dag
+
+
+class TriggerDagTests(unittest.TestCase):
+
+    @mock.patch('airflow.models.DagRun')
+    @mock.patch('airflow.models.DagBag')
+    def test_trigger_dag_dag_not_found(self, dag_bag_mock, dag_run_mock):
+        dag_bag_mock.dags = []
+        self.assertRaises(
+            AirflowException,
+            _trigger_dag,
+            'dag_not_found',
+            dag_bag_mock,
+            dag_run_mock,
+            run_id=None,
+            conf=None,
+            execution_date=None,
+            replace_microseconds=True,
+        )
+
+    @mock.patch('airflow.models.DagRun')
+    @mock.patch('airflow.models.DagBag')
+    def test_trigger_dag_dag_run_exist(self, dag_bag_mock, dag_run_mock):
+        dag_id = "dag_run_exist"
+        dag = DAG(dag_id)
+        dag_bag_mock.dags = [dag_id]
+        dag_bag_mock.get_dag.return_value = dag
+        dag_run_mock.find.return_value = DagRun()
+        self.assertRaises(
+            AirflowException,
+            _trigger_dag,
+            dag_id,
+            dag_bag_mock,
+            dag_run_mock,
+            run_id=None,
+            conf=None,
+            execution_date=None,
+            replace_microseconds=True,
+        )
+
+    @mock.patch('airflow.models.DAG')
+    @mock.patch('airflow.models.DagRun')
+    @mock.patch('airflow.models.DagBag')
+    def test_trigger_dag_include_subdags(self, dag_bag_mock, dag_run_mock, 
dag_mock):
+        dag_id = "trigger_dag"
+        dag_bag_mock.dags = [dag_id]
+        dag_bag_mock.get_dag.return_value = dag_mock
+        dag_run_mock.find.return_value = None
+        dag1 = mock.MagicMock()
+        dag1.subdags = []
+        dag2 = mock.MagicMock()
+        dag2.subdags = []
+        dag_mock.subdags = [dag1, dag2]
+
+        triggers = _trigger_dag(
+            dag_id,
+            dag_bag_mock,
+            dag_run_mock,
+            run_id=None,
+            conf=None,
+            execution_date=None,
+            replace_microseconds=True)
+
+        self.assertEqual(3, len(triggers))
+
+
+if __name__ == '__main__':
+    unittest.main()

Reply via email to