Repository: incubator-airflow
Updated Branches:
  refs/heads/master 2e3f07ff9 -> e6d3160a0

[AIRFLOW-1140] DatabricksSubmitRunOperator should template the "json" field.

Add "json" in the template_fields list for the
DatabricksSubmitRunOperator.

Closes #2255 from andrewmchen/DatabricksOperator-templated

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/e6d3160a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/e6d3160a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/e6d3160a

Branch: refs/heads/master
Commit: e6d3160a061dbaa6042d524095dcd1cbc15e0bcd
Parents: 2e3f07f
Author: Andrew Chen <[email protected]>
Authored: Mon May 1 23:24:24 2017 +0200
Committer: Bolke de Bruin <[email protected]>
Committed: Mon May 1 23:24:24 2017 +0200

----------------------------------------------------------------------
 .../contrib/operators/databricks_operator.py | 60 ++++++++++++++++++--
 docs/concepts.rst                            |  2 +
 .../operators/test_databricks_operator.py    | 57 +++++++++++++++----
 3 files changed, 102 insertions(+), 17 deletions(-)
----------------------------------------------------------------------
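For context, a minimal sketch of what this change enables: Jinja macros inside
the "json" argument are now rendered at task runtime. The DAG name, cluster
spec, and notebook path below are illustrative, not taken from the commit:

    from datetime import datetime

    from airflow.models import DAG
    from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator

    dag = DAG('databricks_templating_example', start_date=datetime(2017, 4, 20))

    notebook_run = DatabricksSubmitRunOperator(
        task_id='notebook_run',
        dag=dag,
        json={
            'new_cluster': {'spark_version': '2.1.0-db3-scala2.11', 'num_workers': 2},
            'notebook_task': {
                # Rendered per run, e.g. '/daily/2017-04-20' for the 2017-04-20 run.
                'notebook_path': '/daily/{{ ds }}',
            },
        })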
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/airflow/contrib/operators/databricks_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
index 46b1659..b4ce502 100644
--- a/airflow/contrib/operators/databricks_operator.py
+++ b/airflow/contrib/operators/databricks_operator.py
@@ -14,6 +14,7 @@
 #
 
 import logging
+import six
 import time
 
 from airflow.exceptions import AirflowException
@@ -81,34 +82,53 @@ class DatabricksSubmitRunOperator(BaseOperator):
         (i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
         be merged with this json dictionary if they are provided.
         If there are conflicts during the merge, the named parameters will
-        take precedence and override the top level json keys.
-        https://docs.databricks.com/api/latest/jobs.html#runs-submit
+        take precedence and override the top level json keys. This field will be
+        templated.
+
+        .. seealso::
+            For more information about templating see :ref:`jinja-templating`.
+            https://docs.databricks.com/api/latest/jobs.html#runs-submit
     :type json: dict
     :param spark_jar_task: The main class and parameters for the JAR task. Note that
         the actual JAR is specified in the ``libraries``.
         *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobssparkjartask
     :type spark_jar_task: dict
     :param notebook_task: The notebook path and parameters for the notebook task.
         *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobsnotebooktask
     :type notebook_task: dict
     :param new_cluster: Specs for a new cluster on which this task will be run.
         *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
-        https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/jobs.html#jobsclusterspecnewcluster
     :type new_cluster: dict
     :param existing_cluster_id: ID for existing cluster on which to run this task.
         *EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified.
+        This field will be templated.
     :type existing_cluster_id: string
     :param libraries: Libraries which this run will use.
-        https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
+        This field will be templated.
+
+        .. seealso::
+            https://docs.databricks.com/api/latest/libraries.html#managedlibrarieslibrary
     :type libraries: list of dicts
     :param run_name: The run name used for this task.
         By default this will be set to the Airflow ``task_id``. This ``task_id`` is a
         required parameter of the superclass ``BaseOperator``.
+        This field will be templated.
     :type run_name: string
     :param timeout_seconds: The timeout for this run. By default a value of 0 is used
         which means to have no timeout.
+        This field will be templated.
     :type timeout_seconds: int32
     :param databricks_conn_id: The name of the Airflow connection to use.
         By default and in the common case this will be ``databricks_default``.
@@ -120,6 +140,8 @@ class DatabricksSubmitRunOperator(BaseOperator):
         unreachable. Its value must be greater than or equal to 1.
     :type databricks_retry_limit: int
     """
+    # Used in airflow.models.BaseOperator
+    template_fields = ('json',)
     # Databricks brand color (blue) under white text
     ui_color = '#1CB1C2'
     ui_fgcolor = '#fff'
@@ -163,9 +185,35 @@ class DatabricksSubmitRunOperator(BaseOperator):
         if 'run_name' not in self.json:
             self.json['run_name'] = run_name or kwargs['task_id']
 
+        self.json = self._deep_string_coerce(self.json)
         # This variable will be used in case our task gets killed.
         self.run_id = None
 
+    def _deep_string_coerce(self, content, json_path='json'):
+        """
+        Coerces content or all values of content if it is a dict to a string. The
+        function will throw if content contains non-string or non-numeric types.
+
+        The reason we have this function is because the ``self.json`` field must be a dict
+        with only string values. This is because ``render_template`` will fail for numerical values.
+        """
+        c = self._deep_string_coerce
+        if isinstance(content, six.string_types):
+            return content
+        elif isinstance(content, six.integer_types+(float,)):
+            # Databricks can tolerate either numeric or string types in the API backend.
+            return str(content)
+        elif isinstance(content, (list, tuple)):
+            return [c(e, '{0}[{1}]'.format(json_path, i)) for i, e in enumerate(content)]
+        elif isinstance(content, dict):
+            return {k: c(v, '{0}[{1}]'.format(json_path, k))
+                    for k, v in list(content.items())}
+        else:
+            param_type = type(content)
+            msg = 'Type {0} used for parameter {1} is not a number or a string' \
+                .format(param_type, json_path)
+            raise AirflowException(msg)
+
     def _log_run_page_url(self, url):
         logging.info('View run status, Spark UI, and logs at {}'.format(url))
 
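A quick illustration of the coercion the constructor now applies (a sketch
that assumes an environment where the contrib operator is importable; the
task id and values are illustrative):

    from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator

    op = DatabricksSubmitRunOperator(
        task_id='coerce_example',
        json={
            'notebook_task': {'notebook_path': '/test'},
            'timeout_seconds': 3600,            # numeric on the way in...
            'new_cluster': {'num_workers': 2},  # ...including nested values
        })

    # ...and strings on the way out, so render_template cannot trip over numbers.
    assert op.json['timeout_seconds'] == '3600'
    assert op.json['new_cluster']['num_workers'] == '2'

Anything that is neither a string nor a number (a datetime, say) raises
AirflowException instead of being silently serialized.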
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/docs/concepts.rst
----------------------------------------------------------------------
diff --git a/docs/concepts.rst b/docs/concepts.rst
index 2760a6f..33a6ea4 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -755,6 +755,8 @@ to the related tasks in Airflow.
 This content will get rendered as markdown respectively in the "Graph View" and
 "Task Details" pages.
 
+.. _jinja-templating:
+
 Jinja Templating
 ================
 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/e6d3160a/tests/contrib/operators/test_databricks_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_databricks_operator.py b/tests/contrib/operators/test_databricks_operator.py
index aab47fa..36932bd 100644
--- a/tests/contrib/operators/test_databricks_operator.py
+++ b/tests/contrib/operators/test_databricks_operator.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 #
 
+import jinja2
 import unittest
+from datetime import datetime
 
 from airflow.contrib.hooks.databricks_hook import RunState
 from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator
 from airflow.exceptions import AirflowException
+from airflow.models import DAG
 
 try:
     from unittest import mock
@@ -27,11 +30,18 @@ except ImportError:
     except ImportError:
         mock = None
 
+DATE = '2017-04-20'
 TASK_ID = 'databricks-operator'
 DEFAULT_CONN_ID = 'databricks_default'
 NOTEBOOK_TASK = {
     'notebook_path': '/test'
 }
+TEMPLATED_NOTEBOOK_TASK = {
+    'notebook_path': '/test-{{ ds }}'
+}
+RENDERED_TEMPLATED_NOTEBOOK_TASK = {
+    'notebook_path': '/test-{0}'.format(DATE)
+}
 SPARK_JAR_TASK = {
     'main_class_name': 'com.databricks.Test'
 }
@@ -51,11 +61,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
         Test the initializer with the named parameters.
         """
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, new_cluster=NEW_CLUSTER, notebook_task=NOTEBOOK_TASK)
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': NEW_CLUSTER,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': TASK_ID
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_json(self):
@@ -67,11 +77,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
             'notebook_task': NOTEBOOK_TASK
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': NEW_CLUSTER,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': TASK_ID
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_specified_run_name(self):
@@ -84,11 +94,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
             'run_name': RUN_NAME
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': NEW_CLUSTER,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': RUN_NAME
-        }
+        })
         self.assertDictEqual(expected, op.json)
 
     def test_init_with_merging(self):
@@ -103,13 +113,38 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
             'notebook_task': NOTEBOOK_TASK,
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json, new_cluster=override_new_cluster)
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': override_new_cluster,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': TASK_ID,
+        })
+        self.assertDictEqual(expected, op.json)
+
+    def test_init_with_templating(self):
+        json = {
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': TEMPLATED_NOTEBOOK_TASK,
         }
+        dag = DAG('test', start_date=datetime.now())
+        op = DatabricksSubmitRunOperator(dag=dag, task_id=TASK_ID, json=json)
+        op.json = op.render_template('json', op.json, {'ds': DATE})
+        expected = op._deep_string_coerce({
+            'new_cluster': NEW_CLUSTER,
+            'notebook_task': RENDERED_TEMPLATED_NOTEBOOK_TASK,
+            'run_name': TASK_ID,
+        })
         self.assertDictEqual(expected, op.json)
 
+    def test_init_with_bad_type(self):
+        json = {
+            'test': datetime.now()
+        }
+        # Looks a bit weird since we have to escape regex reserved symbols.
+        exception_message = 'Type \<(type|class) \'datetime.datetime\'\> used ' + \
+            'for parameter json\[test\] is not a number or a string'
+        with self.assertRaisesRegexp(AirflowException, exception_message):
+            op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=json)
+
     @mock.patch('airflow.contrib.operators.databricks_operator.DatabricksHook')
     def test_exec_success(self, db_mock_class):
         """
@@ -126,11 +161,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
 
         op.execute(None)
 
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': NEW_CLUSTER,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': TASK_ID
-        }
+        })
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             retry_limit=op.databricks_retry_limit)
@@ -156,11 +191,11 @@ class DatabricksSubmitRunOperatorTest(unittest.TestCase):
         with self.assertRaises(AirflowException):
             op.execute(None)
 
-        expected = {
+        expected = op._deep_string_coerce({
             'new_cluster': NEW_CLUSTER,
             'notebook_task': NOTEBOOK_TASK,
             'run_name': TASK_ID,
-        }
+        })
         db_mock_class.assert_called_once_with(
             DEFAULT_CONN_ID,
             retry_limit=op.databricks_retry_limit)
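As test_init_with_templating above demonstrates, rendering can also be driven
by hand outside a scheduled run; a condensed sketch of that pattern (the DAG
name, date, and path are illustrative):

    from datetime import datetime

    from airflow.models import DAG
    from airflow.contrib.operators.databricks_operator import DatabricksSubmitRunOperator

    dag = DAG('render_example', start_date=datetime(2017, 4, 20))
    op = DatabricksSubmitRunOperator(
        dag=dag,
        task_id='render_example_task',
        json={'notebook_task': {'notebook_path': '/test-{{ ds }}'}})

    # In production the scheduler builds the template context; here a minimal
    # one is passed by hand, exactly as the new unit test does.
    op.json = op.render_template('json', op.json, {'ds': '2017-04-20'})
    assert op.json['notebook_task']['notebook_path'] == '/test-2017-04-20'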
