Repository: incubator-airflow
Updated Branches:
  refs/heads/master 2e3f07ff9 -> e6d3160a0


[AIRFLOW-1140] DatabricksSubmitRunOperator should template the "json" field.

Add "json" in the templated_fields list for the
DatabricksSubmitRunOperator.

Closes #2255 from andrewmchen/DatabricksOperator-templated
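
For context, a minimal sketch of the usage this change enables; the DAG,
cluster spec, and notebook path below are hypothetical:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.databricks_operator import \
        DatabricksSubmitRunOperator

    dag = DAG('databricks_templating_example',
              start_date=datetime(2017, 4, 20))

    # With 'json' in template_fields, Jinja expressions such as {{ ds }} are
    # rendered anywhere inside the nested dict before the run is submitted.
    notebook_run = DatabricksSubmitRunOperator(
        task_id='notebook_run',
        dag=dag,
        json={
            'new_cluster': {
                'spark_version': '2.1.0-db3-scala2.11',
                'num_workers': 2,
            },
            'notebook_task': {
                'notebook_path': '/Users/example/PrepareData-{{ ds }}',
            },
        },
    )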


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e6d3160a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e6d3160a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e6d3160a

Branch: refs/heads/master
Commit: e6d3160a061dbaa6042d524095dcd1cbc15e0bcd
Parents: 2e3f07f
Author: Andrew Chen <[email protected]>
Authored: Mon May 1 23:24:24 2017 +0200
Committer: Bolke de Bruin <[email protected]>
Committed: Mon May 1 23:24:24 2017 +0200

----------------------------------------------------------------------
 .../contrib/operators/databricks_operator.py    | 60 ++++++++++++++++++--
 docs/concepts.rst                               |  2 +
 .../operators/test_databricks_operator.py       | 57 +++++++++++++++----
 3 files changed, 102 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 46b1659..b4ce502 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -14,6 +14,7 @@
 #
 
 import logging
+import six
 import time
 
 from airflow.exceptions import AirflowException
@@ -81,34 +82,53 @@ class DatabricksSubmitRunOperator(BaseOperator):
         (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
         be merged with this json dictionary if they are provided.
         If there are conflicts during the merge, the named parameters will
-        take precedence and override the top level json keys.
-        https://docs.databricks.com/api/latest/jobs.html#runs-submit
+        take precedence and override the top level json keys. This field will be
+        templated.
+
+        .. seealso::
+            For more information about templating see :ref:`jinja-templating`.
+            https://docs.databricks.com/api/latest/jobs.html#runs-submit
     :type json: dict
     :param spark_jar_task: The main class and parameters for the JAR task. Note that
         the actual JAR is specified in the ``libraries``.
         *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
     :type spark_jar_task: dict
     :param notebook_task: The notebook path and parameters for the notebook task.
         *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
     :type notebook_task: dict
     :param new_cluster: Specs for a new cluster on which this task will be run.
         *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
     :type new_cluster: dict
     :param existing_cluster_id: ID for existing cluster on which to run this task.
         *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+        This field will be templated.
     :type existing_cluster_id: string
     :param libraries: Libraries which this run will use.
-        https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
     :type libraries: list of dicts
     :param run_name: The run name used for this task.
         By default this will be set to the Airflow ``task_id``. This ``task_id`` is a
         required parameter of the superclass ``BaseOperator``.
+        This field will be templated.
     :type run_name: string
     :param timeout_seconds: The timeout for this run. By default a value of 0 is used
         which means to have no timeout.
+        This field will be templated.
     :type timeout_seconds: int32
     :param databricks_conn_id: The name of the Airflow connection to use.
         By default and in the common case this will be ``databricks_default``.
@@ -120,6 +140,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
         unreachable. Its value must be greater than or equal to 1.
     :type databricks_retry_limit: int
     """
+    # Used in airflow.models.BaseOperator
+    template_fields = ('json',)
     # Databricks brand color (blue) under white text
     ui_color = '#1CB1C2'
     ui_fgcolor = '#fff'
@@ -163,9 +185,35 @@ class DatabricksSubmitRunOperator(BaseOperator):
         if 'run_name' not in self.json:
             self.json['run_name'] = run_name or kwargs['task_id']
 
+        self.json = self._deep_string_coerce(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id = None
 
+    def _deep_string_coerce(self, content, json_path='json'):
+        """
+        Coerces content, or all values of content if it is a dict, to a string.
+        Raises an AirflowException for non-string, non-numeric types.
+
+        This coercion is needed because ``render_template`` fails on numeric
+        values, so ``self.json`` must contain only string leaf values.
+        """
+        c = self._deep_string_coerce
+        if isinstance(content, six.string_types):
+            return content
+        elif isinstance(content, six.integer_types+(float,)):
+            # Databricks can tolerate either numeric or string types in the API backend.
+            return str(content)
+        elif isinstance(content, (list, tuple)):
+            return [c(e, '{0}[{1}]'.format(json_path, i))
+                    for i, e in enumerate(content)]
+        elif isinstance(content, dict):
+            return {k: c(v, '{0}[{1}]'.format(json_path, k))
+                    for k, v in list(content.items())}
+        else:
+            param_type = type(content)
+            msg = 'Type {0} used for parameter {1} is not a number or a string' \
+                    .format(param_type, json_path)
+            raise AirflowException(msg)
+
     def _log_run_page_url(self, url):
         logging.info('View run status, Spark UI, and logs at {}'.format(url))
 
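A short illustration of the coercion semantics introduced above. Since
_deep_string_coerce is a private method, the standalone Python 3 sketch below
only mirrors its behavior (the committed code uses six for 2/3 compatibility;
the values shown are hypothetical):

    def deep_string_coerce(content, path='json'):
        # Strings pass through; numbers are stringified; lists and dicts
        # recurse with an updated path used only for error reporting.
        if isinstance(content, str):
            return content
        elif isinstance(content, (int, float)):
            return str(content)
        elif isinstance(content, (list, tuple)):
            return [deep_string_coerce(e, '{0}[{1}]'.format(path, i))
                    for i, e in enumerate(content)]
        elif isinstance(content, dict):
            return {k: deep_string_coerce(v, '{0}[{1}]'.format(path, k))
                    for k, v in content.items()}
        raise ValueError('Type {0} used for parameter {1} is not a number '
                         'or a string'.format(type(content), path))

    # Prints {'timeout_seconds': '60', 'libraries': [{'jar': 'dbfs:/a.jar'}]}
    print(deep_string_coerce({'timeout_seconds': 60,
                              'libraries': [{'jar': 'dbfs:/a.jar'}]}))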

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/docs/concepts.rst
----------------------------------------------------------------------
diff --git a/docs/concepts.rst b/docs/concepts.rst
index 2760a6f..33a6ea4 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -755,6 +755,8 @@ to the related tasks in Airflow.
 This content will get rendered as markdown respectively in the "Graph View" and
 "Task Details" pages.
 
+.. _jinja-templating:
+
 Jinja Templating
 ================
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/tests/contrib/operators/test_databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index aab47fa..36932bd 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 #
 
+import jinja2
 import unittest
+from datetime import datetime
 
 from airflow.contrib.hooks.databricks_hook import RunState
 from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
 from airflow.exceptions import AirflowException
+from airflow.models import DAG
 
 try:
     from unittest import mock
@@ -27,11 +30,18 @@ except ImportError:
     except ImportError:
         mock = None
 
+DATE = '2017-04-20'
 TASK_ID = 'databricks-operator'
 DEFAULT_CONN_ID = 'databricks_default'
 NOTEBOOK_TASK = {
     'notebook_path': '/test'
 }
+TEMPLATED_NOTEBOOK_TASK = {
+    'notebook_path': '/test-{{ ds }}'
+}
+RENDERED_TEMPLATED_NOTEBOOK_TASK = {
+    'notebook_path': '/test-{0}'.format(DATE)
+}
 SPARK_JAR_TASK = {
     'main_class_name': 'com.databricks.Test'
 }
@@ -51,11 +61,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
         Test the initializer with the named parameters.
         """
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_json(self):
@@ -67,11 +77,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
           'notebook_task': NOTEBOOK_TASK
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_specified_run_name(self):
@@ -84,11 +94,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
           'run_name': RUN_NAME
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': RUN_NAME
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_merging(self):
@@ -103,13 +113,38 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
           'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': override_new_cluster,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID,
+        })
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_templating(self):
+        json = {
+          'new_cluster': NEW_CLUSTER,
+          'notebook_task': TEMPLATED_NOTEBOOK_TASK,
         }
+        dag = DAG('test', start_date=datetime.now())
+        op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
+        op.json = op.render_template('json', op.json, {'ds': DATE})
+        expected = op._deep_string_coerce({
+          'new_cluster': NEW_CLUSTER,
+          'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
+          'run_name': TASK_ID,
+        })
         self.assertDictEqual(expected, op.json)
 
+    def test_init_with_bad_type(self):
+        json = {
+            'test': datetime.now()
+        }
+        # Looks a bit weird since we have to escape regex reserved symbols.
+        exception_message = 'Type \<(type|class) \'datetime.datetime\'\> used ' + \
+                        'for parameter json\[test\] is not a number or a string'
+        with self.assertRaisesRegexp(AirflowException, exception_message):
+            op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+
     @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
     def test_exec_success(self, db_mock_class):
         """
@@ -126,11 +161,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
 
         op.execute(None)
 
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID
-        }
+        })
         db_mock_class.assert_called_once_with(
                 DEFAULT_CONN_ID,
                 retry_limit=op.databricks_retry_limit)
@@ -156,11 +191,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
         with self.assertRaises(AirflowException):
             op.execute(None)
 
-        expected = {
+        expected = op._deep_string_coerce({
           'new_cluster': NEW_CLUSTER,
           'notebook_task': NOTEBOOK_TASK,
           'run_name': TASK_ID,
-        }
+        })
         db_mock_class.assert_called_once_with(
                 DEFAULT_CONN_ID,
                 retry_limit=op.databricks_retry_limit)
