[AIRFLOW-1191] SparkSubmitHook custom cmd

Add the capability to set the spark-submit binary to call. The default behaviour sets the spark-submit command to 'spark-submit', or resolves it via a Spark environment variable; the spark binary can now be set explicitly in the Spark connection.

Test coverage extended for the new settings.
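
For illustration only (not part of this commit): a minimal sketch of how a Spark connection carrying the new setting could be registered and then picked up by the hook. The conn_id, Spark home and binary name below are made-up examples; the extra keys 'spark-home' and 'spark-binary' are the ones read by _resolve_connection().

    # Sketch only: register a Spark connection whose extra field carries the
    # new 'spark-binary' setting (conn_id, paths and binary name are illustrative).
    from airflow import models
    from airflow.utils import db
    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    db.merge_conn(
        models.Connection(
            conn_id='spark_custom_binary',  # hypothetical connection id
            conn_type='spark',
            host='yarn',
            extra='{"spark-home": "/opt/spark", "spark-binary": "spark2-submit"}'))

    # The hook resolves the extras; the submit command then starts with
    # /opt/spark/bin/spark2-submit --master yarn ...
    hook = SparkSubmitHook(conn_id='spark_custom_binary')
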
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/d06ab68f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/d06ab68f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/d06ab68f

Branch: refs/heads/master
Commit: d06ab68f2c83ad5dce3cae1c5aa9a9a9f32cf934
Parents: d165377
Author: vfoucault <[email protected]>
Authored: Sun May 14 23:23:11 2017 +0200
Committer: vfoucault <[email protected]>
Committed: Mon May 22 23:43:46 2017 +0200

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py    | 49 ++++++------
 tests/contrib/hooks/test_spark_submit_hook.py | 88 ++++++++++++++++++++--
 2 files changed, 104 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d06ab68f/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index 208b74f..ae51959 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -100,40 +100,39 @@ class SparkSubmitHook(BaseHook):
         self._sp = None
         self._yarn_application_id = None

-        (self._master, self._queue, self._deploy_mode, self._spark_home) = self._resolve_connection()
-        self._is_yarn = 'yarn' in self._master
+        self._connection = self._resolve_connection()
+        self._is_yarn = 'yarn' in self._connection['master']

     def _resolve_connection(self):
         # Build from connection master or default to yarn if not available
-        master = 'yarn'
-        queue = None
-        deploy_mode = None
-        spark_home = None
+        conn_data = {'master': 'yarn',
+                     'queue': None,
+                     'deploy_mode': None,
+                     'spark_home': None,
+                     'spark_binary': 'spark-submit'}
         try:
             # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
             conn = self.get_connection(self._conn_id)
             if conn.port:
-                master = "{}:{}".format(conn.host, conn.port)
+                conn_data['master'] = "{}:{}".format(conn.host, conn.port)
             else:
-                master = conn.host
+                conn_data['master'] = conn.host

             # Determine optional yarn queue from the extra field
             extra = conn.extra_dejson
-            if 'queue' in extra:
-                queue = extra['queue']
-            if 'deploy-mode' in extra:
-                deploy_mode = extra['deploy-mode']
-            if 'spark-home' in extra:
-                spark_home = extra['spark-home']
+            conn_data['queue'] = extra.get('queue', None)
+            conn_data['deploy_mode'] = extra.get('deploy-mode', None)
+            conn_data['spark_home'] = extra.get('spark-home', None)
+            conn_data['spark_binary'] = extra.get('spark-binary', 'spark-submit')
         except AirflowException:
             logging.debug(
                 "Could not load connection string {}, defaulting to {}".format(
-                    self._conn_id, master
+                    self._conn_id, conn_data['master']
                 )
             )

-        return master, queue, deploy_mode, spark_home
+        return conn_data

     def get_conn(self):
         pass

@@ -148,13 +147,13 @@ class SparkSubmitHook(BaseHook):
         # If the spark_home is passed then build the spark-submit executable path using
         # the spark_home; otherwise assume that spark-submit is present in the path to
         # the executing user
-        if self._spark_home:
-            connection_cmd = [os.path.join(self._spark_home, 'bin', 'spark-submit')]
+        if self._connection['spark_home']:
+            connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])]
         else:
-            connection_cmd = ['spark-submit']
+            connection_cmd = [self._connection['spark_binary']]

         # The url ot the spark master
-        connection_cmd += ["--master", self._master]
+        connection_cmd += ["--master", self._connection['master']]

         if self._conf:
             for key in self._conf:
@@ -185,10 +184,10 @@ class SparkSubmitHook(BaseHook):
             connection_cmd += ["--class", self._java_class]
         if self._verbose:
             connection_cmd += ["--verbose"]
-        if self._queue:
-            connection_cmd += ["--queue", self._queue]
-        if self._deploy_mode:
-            connection_cmd += ["--deploy-mode", self._deploy_mode]
+        if self._connection['queue']:
+            connection_cmd += ["--queue", self._connection['queue']]
+        if self._connection['deploy_mode']:
+            connection_cmd += ["--deploy-mode", self._connection['deploy_mode']]

         # The actual script to execute
         connection_cmd += [application]
@@ -245,7 +244,7 @@ class SparkSubmitHook(BaseHook):
             line = line.decode('utf-8').strip()
             # If we run yarn cluster mode, we want to extract the application id from
             # the logs so we can kill the application when we stop it unexpectedly
-            if self._is_yarn and self._deploy_mode == 'cluster':
+            if self._is_yarn and self._connection['deploy_mode'] == 'cluster':
                 match = re.search('(application[0-9_]+)', line)
                 if match:
                     self._yarn_application_id = match.groups()[0]

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d06ab68f/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index ee5b9e0..80b5ce0 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -90,6 +90,16 @@ class TestSparkSubmitHook(unittest.TestCase):
                 conn_id='spark_home_not_set', conn_type='spark',
                 host='yarn://yarn-master')
         )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_binary_set', conn_type='spark',
+                host='yarn', extra='{"spark-binary": "custom-spark-submit"}')
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_binary_and_home_set', conn_type='spark',
+                host='yarn', extra='{"spark-home": "/path/to/spark_home", "spark-binary": "custom-spark-submit"}')
+        )

     def test_build_command(self):
         # Given
@@ -123,8 +133,6 @@ class TestSparkSubmitHook(unittest.TestCase):
         ]
         self.assertEquals(expected_build_cmd, cmd)

-
-
     @patch('subprocess.Popen')
     def test_SparkProcess_runcmd(self, mock_popen):
         # Given
@@ -150,7 +158,12 @@ class TestSparkSubmitHook(unittest.TestCase):

         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn', None, None, None))
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")

     def test_resolve_connection_yarn_default_connection(self):
@@ -163,7 +176,12 @@ class TestSparkSubmitHook(unittest.TestCase):

         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn', 'root.default', None, None))
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": u"root.default",
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")
         self.assertEqual(dict_cmd["--queue"], "root.default")

@@ -177,7 +195,12 @@ class TestSparkSubmitHook(unittest.TestCase):

         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('mesos://host:5050', None, None, None))
+        expected_spark_connection = {"master": u"mesos://host:5050",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "mesos://host:5050")

     def test_resolve_connection_spark_yarn_cluster_connection(self):
@@ -190,7 +213,12 @@ class TestSparkSubmitHook(unittest.TestCase):

         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', 'root.etl', 'cluster', None))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": u"cluster",
+                                     "queue": u"root.etl",
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
         self.assertEqual(dict_cmd["--queue"], "root.etl")
         self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
@@ -204,7 +232,12 @@ class TestSparkSubmitHook(unittest.TestCase):
         cmd = hook._build_command(self._spark_job_file)

         # Then
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, '/opt/myspark'))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": u"/opt/myspark"}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit')

     def test_resolve_connection_spark_home_not_set_connection(self):
@@ -216,9 +249,48 @@ class TestSparkSubmitHook(unittest.TestCase):
         cmd = hook._build_command(self._spark_job_file)

         # Then
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, None))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], 'spark-submit')

+    def test_resolve_connection_spark_binary_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_binary_set')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": u"custom-spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(cmd[0], 'custom-spark-submit')
+
+    def test_resolve_connection_spark_binary_and_home_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_binary_and_home_set')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": u"custom-spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": u"/path/to/spark_home"}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
+
     def test_process_log(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
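
Usage note (illustrative, not part of the commit): a DAG task only needs to point at such a connection via conn_id for the custom binary to be used at submit time. A sketch assuming the contrib SparkSubmitOperator, with made-up DAG id, dates and application path:

    # Sketch only: submit an application through the connection registered above.
    from datetime import datetime
    from airflow import DAG
    from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator

    dag = DAG('spark_custom_binary_example',
              start_date=datetime(2017, 5, 1),
              schedule_interval=None)

    submit_job = SparkSubmitOperator(
        task_id='submit_job',
        application='/path/to/app.py',   # hypothetical application
        conn_id='spark_custom_binary',   # connection carrying the spark-binary extra
        dag=dag)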
