Repository: incubator-airflow
Updated Branches:
  refs/heads/master 22453d037 -> 3e6babe8e
[AIRFLOW-1854] Improve Spark Submit operator for standalone cluster mode Closes #2852 from milanvdmria/svend/submit2 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3e6babe8 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3e6babe8 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3e6babe8 Branch: refs/heads/master Commit: 3e6babe8ed8f8f281b67aa3f4e03bf3cfc1bcbaa Parents: 22453d0 Author: milanvdmria <[email protected]> Authored: Tue Dec 12 12:45:41 2017 +0100 Committer: Bolke de Bruin <[email protected]> Committed: Tue Dec 12 12:45:52 2017 +0100 ---------------------------------------------------------------------- airflow/contrib/hooks/spark_submit_hook.py | 217 ++++++++++++++++--- .../contrib/operators/spark_submit_operator.py | 17 +- tests/contrib/hooks/test_spark_submit_hook.py | 175 ++++++++++++--- .../operators/test_spark_submit_operator.py | 19 +- 4 files changed, 354 insertions(+), 74 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/airflow/contrib/hooks/spark_submit_hook.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py index c0bc84f..16e14b4 100644 --- a/airflow/contrib/hooks/spark_submit_hook.py +++ b/airflow/contrib/hooks/spark_submit_hook.py @@ -15,6 +15,7 @@ import os import subprocess import re +import time from airflow.hooks.base_hook import BaseHook from airflow.exceptions import AirflowException @@ -42,15 +43,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin): :type jars: str :param java_class: the main class of the Java application :type java_class: str - :param packages: Comma-separated list of maven coordinates of jars to include on the driver and executor classpaths + :param packages: Comma-separated list of maven coordinates of jars to include on the + driver and executor classpaths :type packages: str - :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude while resolving the dependencies provided in 'packages' + :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude + while resolving the dependencies provided in 'packages' :type exclude_packages: str - :param repositories: Comma-separated list of additional remote repositories to search for the maven coordinates given with 'packages' + :param repositories: Comma-separated list of additional remote repositories to search + for the maven coordinates given with 'packages' :type repositories: str - :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors (Default: all the available cores on the worker) + :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors + (Default: all the available cores on the worker) :type total_executor_cores: int - :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2) + :param executor_cores: (Standalone & YARN only) Number of cores per executor + (Default: 2) :type executor_cores: int :param executor_memory: Memory per executor (e.g. 
1000M, 2G) (Default: 1G) :type executor_memory: str @@ -110,12 +116,25 @@ class SparkSubmitHook(BaseHook, LoggingMixin): self._num_executors = num_executors self._application_args = application_args self._verbose = verbose - self._sp = None + self._submit_sp = None self._yarn_application_id = None self._connection = self._resolve_connection() self._is_yarn = 'yarn' in self._connection['master'] + self._should_track_driver_status = self._resolve_should_track_driver_status() + self._driver_id = None + self._driver_status = None + + def _resolve_should_track_driver_status(self): + """ + Determines whether or not this hook should poll the spark driver status through + subsequent spark-submit status requests after the initial spark-submit request + :return: if the driver status should be tracked + """ + return ('spark://' in self._connection['master'] and + self._connection['deploy_mode'] == 'cluster') + def _resolve_connection(self): # Build from connection master or default to yarn if not available conn_data = {'master': 'yarn', @@ -149,21 +168,27 @@ class SparkSubmitHook(BaseHook, LoggingMixin): def get_conn(self): pass - def _build_command(self, application): - """ - Construct the spark-submit command to execute. - :param application: command to append to the spark-submit command - :type application: str - :return: full command to be executed - """ + def _get_spark_binary_path(self): # If the spark_home is passed then build the spark-submit executable path using # the spark_home; otherwise assume that spark-submit is present in the path to # the executing user if self._connection['spark_home']: - connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])] + connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', + self._connection['spark_binary'])] else: connection_cmd = [self._connection['spark_binary']] + return connection_cmd + + def _build_spark_submit_command(self, application): + """ + Construct the spark-submit command to execute. + :param application: command to append to the spark-submit command + :type application: str + :return: full command to be executed + """ + connection_cmd = self._get_spark_binary_path() + # The url ot the spark master connection_cmd += ["--master", self._connection['master']] @@ -216,7 +241,30 @@ class SparkSubmitHook(BaseHook, LoggingMixin): if self._application_args: connection_cmd += self._application_args - self.log.debug("Spark-Submit cmd: %s", connection_cmd) + self.log.info("Spark-Submit cmd: %s", connection_cmd) + + return connection_cmd + + def _build_track_driver_status_command(self): + """ + Construct the command to poll the driver status. + + :return: full command to be executed + """ + connection_cmd = self._get_spark_binary_path() + + # The url ot the spark master + connection_cmd += ["--master", self._connection['master']] + + # The driver id so we can poll for its status + if self._driver_id: + connection_cmd += ["--status", self._driver_id] + else: + raise AirflowException( + "Invalid status: attempted to poll driver " + + "status but no driver id is known. 
Giving up.") + + self.log.debug("Poll driver status cmd: %s", connection_cmd) return connection_cmd @@ -228,16 +276,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin): :type application: str :param kwargs: extra arguments to Popen (see subprocess.Popen) """ - spark_submit_cmd = self._build_command(application) - self._sp = subprocess.Popen(spark_submit_cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - bufsize=-1, - universal_newlines=True, - **kwargs) + spark_submit_cmd = self._build_spark_submit_command(application) + self._submit_sp = subprocess.Popen(spark_submit_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True, + **kwargs) - self._process_log(iter(self._sp.stdout.readline, '')) - returncode = self._sp.wait() + self._process_spark_submit_log(iter(self._submit_sp.stdout.readline, '')) + returncode = self._submit_sp.wait() if returncode: raise AirflowException( @@ -246,9 +294,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin): ) ) - def _process_log(self, itr): + self.log.debug("Should track driver: {}".format(self._should_track_driver_status)) + + # We want the Airflow job to wait until the Spark driver is finished + if self._should_track_driver_status: + if self._driver_id is None: + raise AirflowException( + "No driver id is known: something went wrong when executing " + + "the spark submit command" + ) + + # We start with the SUBMITTED status as initial status + self._driver_status = "SUBMITTED" + + # Start tracking the driver status (blocking function) + self._start_driver_status_tracking() + + if self._driver_status != "FINISHED": + raise AirflowException( + "ERROR : Driver {} badly exited with status {}" + .format(self._driver_id, self._driver_status) + ) + + def _process_spark_submit_log(self, itr): """ - Processes the log files and extracts useful information out of it + Processes the log files and extracts useful information out of it. + + Remark: If the driver needs to be tracked for its status, the log-level of the + spark deploy needs to be at least INFO (log4j.logger.org.apache.spark.deploy=INFO) :param itr: An iterator which iterates over the input of the subprocess """ @@ -262,16 +335,94 @@ class SparkSubmitHook(BaseHook, LoggingMixin): if match: self._yarn_application_id = match.groups()[0] - # Pass to logging - self.log.info(line) + # if we run in standalone cluster mode and we want to track the driver status + # we need to extract the driver id from the logs. This allows us to poll for + # the status using the driver id. Also, we can kill the driver when needed. + if self._should_track_driver_status and not self._driver_id: + match_driver_id = re.search('(driver-[0-9\-]+)', line) + if match_driver_id: + self._driver_id = match_driver_id.groups()[0] + self.log.info("identified spark driver id: {}" + .format(self._driver_id)) + + self.log.debug("spark submit log: {}".format(line)) + + def _process_spark_status_log(self, itr): + """ + parses the logs of the spark driver status query process + + :param itr: An iterator which iterates over the input of the subprocess + """ + # Consume the iterator + for line in itr: + line = line.strip() + + # Check if the log line is about the driver status and extract the status. + if "driverState" in line: + self._driver_status = line.split(' : ')[1] \ + .replace(',', '').replace('\"', '').strip() + + self.log.debug("spark driver status log: {}".format(line)) + + def _start_driver_status_tracking(self): + """ + Polls the driver based on self._driver_id to get the status. 
+ Finish successfully when the status is FINISHED. + Finish failed when the status is ERROR/UNKNOWN/KILLED/FAILED. + + Possible status: + SUBMITTED: Submitted but not yet scheduled on a worker + RUNNING: Has been allocated to a worker to run + FINISHED: Previously ran and exited cleanly + RELAUNCHING: Exited non-zero or due to worker failure, but has not yet + started running again + UNKNOWN: The status of the driver is temporarily not known due to + master failure recovery + KILLED: A user manually killed this driver + FAILED: The driver exited non-zero and was not supervised + ERROR: Unable to run or restart due to an unrecoverable error + (e.g. missing jar file) + """ + # Keep polling as long as the driver is processing + while self._driver_status not in ["FINISHED", "UNKNOWN", + "KILLED", "FAILED", "ERROR"]: + + # Sleep for 1 second as we do not want to spam the cluster + time.sleep(1) + + self.log.debug("polling status of spark driver with id {}" + .format(self._driver_id)) + + poll_drive_status_cmd = self._build_track_driver_status_command() + status_process = subprocess.Popen(poll_drive_status_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=-1, + universal_newlines=True) + + self._process_spark_status_log(iter(status_process.stdout.readline, '')) + returncode = status_process.wait() + + if returncode: + raise AirflowException( + "Failed to poll for the driver status: returncode = {}" + .format(returncode) + ) def on_kill(self): - if self._sp and self._sp.poll() is None: + + if self._submit_sp and self._submit_sp.poll() is None: self.log.info('Sending kill signal to %s', self._connection['spark_binary']) - self._sp.kill() + self._submit_sp.kill() if self._yarn_application_id: - self.log.info('Killing application on YARN') - kill_cmd = "yarn application -kill {0}".format(self._yarn_application_id).split() - yarn_kill = subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + self.log.info('Killing application {} on YARN' + .format(self._yarn_application_id)) + + kill_cmd = "yarn application -kill {}" \ + .format(self._yarn_application_id).split() + yarn_kill = subprocess.Popen(kill_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.log.info("YARN killed with return code: %s", yarn_kill.wait()) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/airflow/contrib/operators/spark_submit_operator.py ---------------------------------------------------------------------- diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py index ae821fa..d743393 100644 --- a/airflow/contrib/operators/spark_submit_operator.py +++ b/airflow/contrib/operators/spark_submit_operator.py @@ -42,15 +42,20 @@ class SparkSubmitOperator(BaseOperator): :type jars: str :param java_class: the main class of the Java application :type java_class: str - :param packages: Comma-separated list of maven coordinates of jars to include on the driver and executor classpaths + :param packages: Comma-separated list of maven coordinates of jars to include on the + driver and executor classpaths :type packages: str - :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude while resolving the dependencies provided in 'packages' + :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude + while resolving the dependencies provided in 'packages' :type exclude_packages: str - :param repositories: Comma-separated list of additional remote 
repositories to search for the maven coordinates given with 'packages' + :param repositories: Comma-separated list of additional remote repositories to search + for the maven coordinates given with 'packages' :type repositories: str - :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors (Default: all the available cores on the worker) + :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors + (Default: all the available cores on the worker) :type total_executor_cores: int - :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2) + :param executor_cores: (Standalone & YARN only) Number of cores per executor + (Default: 2) :type executor_cores: int :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G) :type executor_memory: str @@ -69,7 +74,7 @@ class SparkSubmitOperator(BaseOperator): :param verbose: Whether to pass the verbose flag to spark-submit process for debugging :type verbose: bool """ - template_fields = ('_name', '_application_args','_packages') + template_fields = ('_name', '_application_args', '_packages') ui_color = WEB_COLORS['LIGHTORANGE'] @apply_defaults http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/tests/contrib/hooks/test_spark_submit_hook.py ---------------------------------------------------------------------- diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py index 5cb7132..6c55ce2 100644 --- a/tests/contrib/hooks/test_spark_submit_hook.py +++ b/tests/contrib/hooks/test_spark_submit_hook.py @@ -13,7 +13,6 @@ # limitations under the License. # import six -import sys import unittest from airflow import configuration, models @@ -61,7 +60,7 @@ class TestSparkSubmitHook(unittest.TestCase): for arg in list_cmd: if arg.startswith("--"): pos = list_cmd.index(arg) - return_dict[arg] = list_cmd[pos+1] + return_dict[arg] = list_cmd[pos + 1] return return_dict def setUp(self): @@ -70,7 +69,8 @@ class TestSparkSubmitHook(unittest.TestCase): db.merge_conn( models.Connection( conn_id='spark_yarn_cluster', conn_type='spark', - host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}') + host='yarn://yarn-master', + extra='{"queue": "root.etl", "deploy-mode": "cluster"}') ) db.merge_conn( models.Connection( @@ -98,15 +98,23 @@ class TestSparkSubmitHook(unittest.TestCase): db.merge_conn( models.Connection( conn_id='spark_binary_and_home_set', conn_type='spark', - host='yarn', extra='{"spark-home": "/path/to/spark_home", "spark-binary": "custom-spark-submit"}') + host='yarn', + extra='{"spark-home": "/path/to/spark_home", ' + + '"spark-binary": "custom-spark-submit"}') + ) + db.merge_conn( + models.Connection( + conn_id='spark_standalone_cluster', conn_type='spark', + host='spark://spark-standalone-master:6066', + extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}') ) - def test_build_command(self): + def test_build_spark_submit_command(self): # Given hook = SparkSubmitHook(**self._config) # When - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then expected_build_cmd = [ @@ -149,7 +157,51 @@ class TestSparkSubmitHook(unittest.TestCase): hook.submit() # Then - self.assertEqual(mock_popen.mock_calls[0], call(['spark-submit', '--master', 'yarn', '--name', 'default-name', ''], stderr=-2, stdout=-1, universal_newlines=True, bufsize=-1)) + self.assertEqual(mock_popen.mock_calls[0], + 
call(['spark-submit', '--master', 'yarn', + '--name', 'default-name', ''], + stderr=-2, stdout=-1, universal_newlines=True, bufsize=-1)) + + def test_resolve_should_track_driver_status(self): + # Given + hook_default = SparkSubmitHook(conn_id='') + hook_spark_yarn_cluster = SparkSubmitHook(conn_id='spark_yarn_cluster') + hook_spark_default_mesos = SparkSubmitHook(conn_id='spark_default_mesos') + hook_spark_home_set = SparkSubmitHook(conn_id='spark_home_set') + hook_spark_home_not_set = SparkSubmitHook(conn_id='spark_home_not_set') + hook_spark_binary_set = SparkSubmitHook(conn_id='spark_binary_set') + hook_spark_binary_and_home_set = SparkSubmitHook( + conn_id='spark_binary_and_home_set') + hook_spark_standalone_cluster = SparkSubmitHook( + conn_id='spark_standalone_cluster') + + # When + should_track_driver_status_default = hook_default \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_yarn_cluster = hook_spark_yarn_cluster \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_default_mesos = hook_spark_default_mesos \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_home_set = hook_spark_home_set \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_home_not_set = hook_spark_home_not_set \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_binary_set = hook_spark_binary_set \ + ._resolve_should_track_driver_status() + should_track_driver_status_spark_binary_and_home_set = \ + hook_spark_binary_and_home_set._resolve_should_track_driver_status() + should_track_driver_status_spark_standalone_cluster = \ + hook_spark_standalone_cluster._resolve_should_track_driver_status() + + # Then + self.assertEqual(should_track_driver_status_default, False) + self.assertEqual(should_track_driver_status_spark_yarn_cluster, False) + self.assertEqual(should_track_driver_status_spark_default_mesos, False) + self.assertEqual(should_track_driver_status_spark_home_set, False) + self.assertEqual(should_track_driver_status_spark_home_not_set, False) + self.assertEqual(should_track_driver_status_spark_binary_set, False) + self.assertEqual(should_track_driver_status_spark_binary_and_home_set, False) + self.assertEqual(should_track_driver_status_spark_standalone_cluster, True) def test_resolve_connection_yarn_default(self): # Given @@ -157,7 +209,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then dict_cmd = self.cmd_args_to_dict(cmd) @@ -175,7 +227,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then dict_cmd = self.cmd_args_to_dict(cmd) @@ -194,7 +246,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then dict_cmd = self.cmd_args_to_dict(cmd) @@ -212,7 +264,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then dict_cmd = self.cmd_args_to_dict(cmd) @@ -232,7 +284,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = 
hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then expected_spark_connection = {"master": "yarn://yarn-master", @@ -249,7 +301,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then expected_spark_connection = {"master": "yarn://yarn-master", @@ -266,7 +318,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then expected_spark_connection = {"master": "yarn", @@ -283,7 +335,7 @@ class TestSparkSubmitHook(unittest.TestCase): # When connection = hook._resolve_connection() - cmd = hook._build_command(self._spark_job_file) + cmd = hook._build_spark_submit_command(self._spark_job_file) # Then expected_spark_connection = {"master": "yarn", @@ -294,25 +346,87 @@ class TestSparkSubmitHook(unittest.TestCase): self.assertEqual(connection, expected_spark_connection) self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit') - def test_process_log(self): + def test_resolve_connection_spark_standalone_cluster_connection(self): + # Given + hook = SparkSubmitHook(conn_id='spark_standalone_cluster') + + # When + connection = hook._resolve_connection() + cmd = hook._build_spark_submit_command(self._spark_job_file) + + # Then + expected_spark_connection = {"master": "spark://spark-standalone-master:6066", + "spark_binary": "spark-submit", + "deploy_mode": "cluster", + "queue": None, + "spark_home": "/path/to/spark_home"} + self.assertEqual(connection, expected_spark_connection) + self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit') + + def test_process_spark_submit_log_yarn(self): # Given hook = SparkSubmitHook(conn_id='spark_yarn_cluster') log_lines = [ 'SPARK_MAJOR_VERSION is set to 2, using Spark2', - 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable', - 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.', + 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + + 'platform... using builtin-java classes where applicable', + 'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' + 'be used because libhadoop cannot be loaded.', 'INFO Client: Requesting a new application from cluster with 10 NodeManagers', - 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager' + 'INFO Client: Submitting application application_1486558679801_1820 ' + + 'to ResourceManager' ] # When - hook._process_log(log_lines) + hook._process_spark_submit_log(log_lines) # Then self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820') + def test_process_spark_submit_log_standalone_cluster(self): + # Given + hook = SparkSubmitHook(conn_id='spark_standalone_cluster') + log_lines = [ + 'Running Spark using the REST application submission protocol.', + '17/11/28 11:14:15 INFO RestSubmissionClient: Submitting a request ' + 'to launch an application in spark://spark-standalone-master:6066', + '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' + + 'created as driver-20171128111415-0001. Polling submission state...' 
+ ] + # When + hook._process_spark_submit_log(log_lines) + + # Then + + self.assertEqual(hook._driver_id, 'driver-20171128111415-0001') + + def test_process_spark_driver_status_log(self): + # Given + hook = SparkSubmitHook(conn_id='spark_standalone_cluster') + log_lines = [ + 'Submitting a request for the status of submission ' + + 'driver-20171128111415-0001 in spark://spark-standalone-master:6066', + '17/11/28 11:15:37 INFO RestSubmissionClient: Server responded with ' + + 'SubmissionStatusResponse:', + '{', + '"action" : "SubmissionStatusResponse",', + '"driverState" : "RUNNING",', + '"serverSparkVersion" : "1.6.0",', + '"submissionId" : "driver-20171128111415-0001",', + '"success" : true,', + '"workerHostPort" : "172.18.0.7:38561",', + '"workerId" : "worker-20171128110741-172.18.0.7-38561"', + '}' + ] + # When + hook._process_spark_status_log(log_lines) + + # Then + + self.assertEqual(hook._driver_status, 'RUNNING') + @patch('airflow.contrib.hooks.spark_submit_hook.subprocess.Popen') - def test_spark_process_on_kill(self, mock_popen): + def test_yarn_process_on_kill(self, mock_popen): # Given mock_popen.return_value.stdout = six.StringIO('stdout') mock_popen.return_value.stderr = six.StringIO('stderr') @@ -320,20 +434,27 @@ class TestSparkSubmitHook(unittest.TestCase): mock_popen.return_value.wait.return_value = 0 log_lines = [ 'SPARK_MAJOR_VERSION is set to 2, using Spark2', - 'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable', - 'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.', - 'INFO Client: Requesting a new application from cluster with 10 NodeManagerapplication_1486558679801_1820s', - 'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager' + 'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' + + 'platform... 
using builtin-java classes where applicable', + 'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' + + 'be used because libhadoop cannot be loaded.', + 'INFO Client: Requesting a new application from cluster with 10 ' + + 'NodeManagerapplication_1486558679801_1820s', + 'INFO Client: Submitting application application_1486558679801_1820 ' + + 'to ResourceManager' ] hook = SparkSubmitHook(conn_id='spark_yarn_cluster') - hook._process_log(log_lines) + hook._process_spark_submit_log(log_lines) hook.submit() # When hook.on_kill() # Then - self.assertIn(call(['yarn', 'application', '-kill', 'application_1486558679801_1820'], stderr=-1, stdout=-1), mock_popen.mock_calls) + self.assertIn(call(['yarn', 'application', '-kill', + 'application_1486558679801_1820'], + stderr=-1, stdout=-1), + mock_popen.mock_calls) if __name__ == '__main__': http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/tests/contrib/operators/test_spark_submit_operator.py ---------------------------------------------------------------------- diff --git a/tests/contrib/operators/test_spark_submit_operator.py b/tests/contrib/operators/test_spark_submit_operator.py index 05ddc32..5903fcd 100644 --- a/tests/contrib/operators/test_spark_submit_operator.py +++ b/tests/contrib/operators/test_spark_submit_operator.py @@ -14,7 +14,6 @@ # import unittest -import sys from airflow import DAG, configuration from airflow.models import TaskInstance @@ -40,7 +39,7 @@ class TestSparkSubmitOperator(unittest.TestCase): 'packages': 'com.databricks:spark-avro_2.11:3.2.0', 'exclude_packages': 'org.bad.dependency:1.0.0', 'repositories': 'http://myrepo.org', - 'total_executor_cores':4, + 'total_executor_cores': 4, 'executor_cores': 4, 'executor_memory': '22g', 'keytab': 'privileged_user.keytab', @@ -107,7 +106,6 @@ class TestSparkSubmitOperator(unittest.TestCase): '--end', '{{ ds }}', '--with-spaces', 'args should keep embdedded spaces', ] - } self.assertEqual(conn_id, operator._conn_id) @@ -120,7 +118,8 @@ class TestSparkSubmitOperator(unittest.TestCase): self.assertEqual(expected_dict['packages'], operator._packages) self.assertEqual(expected_dict['exclude_packages'], operator._exclude_packages) self.assertEqual(expected_dict['repositories'], operator._repositories) - self.assertEqual(expected_dict['total_executor_cores'], operator._total_executor_cores) + self.assertEqual(expected_dict['total_executor_cores'], + operator._total_executor_cores) self.assertEqual(expected_dict['executor_cores'], operator._executor_cores) self.assertEqual(expected_dict['executor_memory'], operator._executor_memory) self.assertEqual(expected_dict['keytab'], operator._keytab) @@ -134,7 +133,8 @@ class TestSparkSubmitOperator(unittest.TestCase): def test_render_template(self): # Given - operator = SparkSubmitOperator(task_id='spark_submit_job', dag=self.dag, **self._config) + operator = SparkSubmitOperator(task_id='spark_submit_job', + dag=self.dag, **self._config) ti = TaskInstance(operator, DEFAULT_DATE) # When @@ -143,12 +143,15 @@ class TestSparkSubmitOperator(unittest.TestCase): # Then expected_application_args = [u'-f', 'foo', u'--bar', 'bar', - u'--start', (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"), + u'--start', (DEFAULT_DATE - timedelta(days=1)) + .strftime("%Y-%m-%d"), u'--end', DEFAULT_DATE.strftime("%Y-%m-%d"), - u'--with-spaces', u'args should keep embdedded spaces', + u'--with-spaces', + u'args should keep embdedded spaces', ] expected_name = "spark_submit_job" - 
self.assertListEqual(expected_application_args, getattr(operator, '_application_args')) + self.assertListEqual(expected_application_args, + getattr(operator, '_application_args')) self.assertEqual(expected_name, getattr(operator, '_name'))
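----------------------------------------------------------------------
For reference, the core of the new behaviour is the driver-status polling added to SparkSubmitHook: when the connection master uses the spark:// scheme with deploy-mode "cluster", the hook extracts the driver id from the spark-submit log and then repeatedly invokes spark-submit --status until the driver reaches a terminal state. A minimal standalone sketch of that mechanism is shown below; the function name, the TERMINAL_STATES constant, the interval parameter, and the example master URL and driver id are illustrative only (the hook derives the real values from the Airflow connection and the submit log, and parses the status output slightly differently).

import re
import subprocess
import time

# Terminal driver states on a Spark standalone master, as listed in the
# _start_driver_status_tracking docstring above.
TERMINAL_STATES = {"FINISHED", "UNKNOWN", "KILLED", "FAILED", "ERROR"}


def poll_driver_state(master, driver_id, interval=1):
    """Poll `spark-submit --status <driver_id>` until a terminal state is reached."""
    state = "SUBMITTED"
    while state not in TERMINAL_STATES:
        # Sleep between polls so the standalone master is not spammed.
        time.sleep(interval)
        output = subprocess.check_output(
            ["spark-submit", "--master", master, "--status", driver_id],
            stderr=subprocess.STDOUT,
            universal_newlines=True)
        # The REST response printed by spark-submit contains a line such as:
        #   "driverState" : "RUNNING",
        match = re.search(r'"driverState"\s*:\s*"(\w+)"', output)
        if match:
            state = match.group(1)
    return state


# Hypothetical values, for illustration only:
# poll_driver_state("spark://spark-standalone-master:6066",
#                   "driver-20171128111415-0001")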

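----------------------------------------------------------------------
To exercise the new code path end to end, a task only needs a Spark connection whose host uses the spark:// scheme (pointing at the standalone master's REST submission port, 6066 by default) and whose extra sets deploy-mode to cluster; the task then blocks until the driver reports FINISHED and fails on ERROR/UNKNOWN/KILLED/FAILED. A hypothetical DAG snippet follows, assuming a connection configured like the spark_standalone_cluster fixture in the tests above; the DAG id, task id, application path, and Java class are placeholders.

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator

# Assumed Airflow connection (illustrative only), e.g. created via the UI or CLI:
#   conn_id:   spark_standalone_cluster
#   conn_type: spark
#   host:      spark://spark-standalone-master:6066
#   extra:     {"deploy-mode": "cluster", "spark-home": "/path/to/spark_home"}
dag = DAG("spark_standalone_cluster_example",
          start_date=datetime(2017, 12, 1),
          schedule_interval=None)

submit_job = SparkSubmitOperator(
    task_id="submit_app_to_standalone_cluster",
    conn_id="spark_standalone_cluster",
    application="/path/to/app.jar",   # placeholder application
    java_class="com.example.App",     # placeholder main class
    dag=dag)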