Repository: incubator-airflow

Updated Branches:
  refs/heads/master 35e43f506 -> 0ade066f4
[AIRFLOW-1085] Enhance the SparkSubmitOperator

- Allow the Spark home to be set on a per-connection basis, obviating the need
  for spark-submit to be on the PATH and making it easy to use different
  versions of Spark.
- Enable the use of the --driver-memory parameter on spark-submit by making it
  a parameter on the operator.
- Enable the use of the --class parameter on spark-submit by making it a
  parameter on the operator.

Closes #2211 from camshrun/sparkSubmitImprovements

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/0ade066f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/0ade066f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/0ade066f

Branch: refs/heads/master
Commit: 0ade066f44257c5e119b292f4cc2ba105774f4e7
Parents: 35e43f5
Author: Stephan Werges <[email protected]>
Authored: Fri Apr 7 19:20:46 2017 +0200
Committer: Bolke de Bruin <[email protected]>
Committed: Fri Apr 7 19:20:58 2017 +0200

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py         | 32 ++++++++++--
 .../contrib/operators/spark_submit_operator.py     | 13 ++++-
 tests/contrib/hooks/spark_submit_hook.py           | 51 +++++++++++++++++---
 .../contrib/operators/spark_submit_operator.py     |  8 ++-
 4 files changed, 90 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index 619cc71..59d28b5 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 import logging
+import os
 import subprocess
 import re
 
@@ -25,7 +26,8 @@ log = logging.getLogger(__name__)
 class SparkSubmitHook(BaseHook):
     """
     This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
-    It requires that the "spark-submit" binary is in the PATH.
+    It requires that the "spark-submit" binary is in the PATH or the spark_home to be
+    supplied.
     :param conf: Arbitrary Spark configuration properties
     :type conf: dict
     :param conn_id: The connection id as configured in Airflow administration. When an
@@ -38,10 +40,14 @@ class SparkSubmitHook(BaseHook):
     :type py_files: str
     :param jars: Submit additional jars to upload and place them in executor classpath.
     :type jars: str
+    :param java_class: the main class of the Java application
+    :type java_class: str
     :param executor_cores: Number of cores per executor (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
+    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
+    :type driver_memory: str
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param principal: The name of the kerberos principal used for keytab
@@ -60,8 +66,10 @@ class SparkSubmitHook(BaseHook):
                  files=None,
                  py_files=None,
                  jars=None,
+                 java_class=None,
                  executor_cores=None,
                  executor_memory=None,
+                 driver_memory=None,
                  keytab=None,
                  principal=None,
                  name='default-name',
@@ -72,8 +80,10 @@ class SparkSubmitHook(BaseHook):
         self._files = files
         self._py_files = py_files
         self._jars = jars
+        self._java_class = java_class
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
+        self._driver_memory = driver_memory
         self._keytab = keytab
         self._principal = principal
         self._name = name
@@ -82,7 +92,7 @@ class SparkSubmitHook(BaseHook):
         self._sp = None
         self._yarn_application_id = None
 
-        (self._master, self._queue, self._deploy_mode) = self._resolve_connection()
+        (self._master, self._queue, self._deploy_mode, self._spark_home) = self._resolve_connection()
         self._is_yarn = 'yarn' in self._master
 
     def _resolve_connection(self):
@@ -90,6 +100,7 @@
         master = 'yarn'
         queue = None
         deploy_mode = None
+        spark_home = None
 
         try:
             # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
@@ -105,6 +116,8 @@
                 queue = extra['queue']
             if 'deploy-mode' in extra:
                 deploy_mode = extra['deploy-mode']
+            if 'spark-home' in extra:
+                spark_home = extra['spark-home']
         except AirflowException:
             logging.debug(
                 "Could not load connection string {}, defaulting to {}".format(
@@ -112,7 +125,7 @@
                 )
             )
 
-        return master, queue, deploy_mode
+        return master, queue, deploy_mode, spark_home
 
     def get_conn(self):
         pass
@@ -124,8 +137,13 @@
         :type application: str
         :return: full command to be executed
         """
-        # The spark-submit binary needs to be in the path
-        connection_cmd = ["spark-submit"]
+        # If the spark_home is passed then build the spark-submit executable path using
+        # the spark_home; otherwise assume that spark-submit is present in the path to
+        # the executing user
+        if self._spark_home:
+            connection_cmd = [os.path.join(self._spark_home, 'bin', 'spark-submit')]
+        else:
+            connection_cmd = ['spark-submit']
 
         # The url ot the spark master
         connection_cmd += ["--master", self._master]
@@ -145,12 +163,16 @@
             connection_cmd += ["--executor-cores", str(self._executor_cores)]
         if self._executor_memory:
             connection_cmd += ["--executor-memory", self._executor_memory]
+        if self._driver_memory:
+            connection_cmd += ["--driver-memory", self._driver_memory]
         if self._keytab:
             connection_cmd += ["--keytab", self._keytab]
         if self._principal:
             connection_cmd += ["--principal", self._principal]
         if self._name:
             connection_cmd += ["--name", self._name]
+        if self._java_class:
+            connection_cmd += ["--class", self._java_class]
         if self._verbose:
             connection_cmd += ["--verbose"]
         if self._queue:
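
For illustration, with the hook changes above a connection whose extra carries
"spark-home" resolves to an absolute spark-submit path when the command is built.
This is only a sketch: the connection id, Spark location, and application path are
examples (mirroring the test fixtures further down), not part of this patch:

    from airflow import models
    from airflow.utils import db
    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    # Example connection; the id and the Spark install location are illustrative only.
    db.merge_conn(models.Connection(
        conn_id='spark_home_example', conn_type='spark',
        host='yarn://yarn-master',
        extra='{"spark-home": "/opt/spark-2.1.0"}'))

    hook = SparkSubmitHook(conn_id='spark_home_example',
                           driver_memory='2g',
                           java_class='com.example.Main')
    # Expected to start with /opt/spark-2.1.0/bin/spark-submit and to include
    # --driver-memory 2g and --class com.example.Main among its arguments.
    print(' '.join(hook._build_command('/path/to/app.jar')))
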

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/airflow/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
index a5e6145..f62c395 100644
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -24,7 +24,8 @@ log = logging.getLogger(__name__)
 class SparkSubmitOperator(BaseOperator):
     """
     This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
-    It requires that the "spark-submit" binary is in the PATH.
+    It requires that the "spark-submit" binary is in the PATH or the spark-home is set
+    in the extra on the connection.
     :param application: The application that submitted as a job, either jar or py file.
     :type application: str
     :param conf: Arbitrary Spark configuration properties
@@ -39,10 +40,14 @@ class SparkSubmitOperator(BaseOperator):
     :type py_files: str
     :param jars: Submit additional jars to upload and place them in executor classpath.
     :type jars: str
+    :param java_class: the main class of the Java application
+    :type java_class: str
     :param executor_cores: Number of cores per executor (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
+    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
+    :type driver_memory: str
     :param keytab: Full path to the file that contains the keytab
     :type keytab: str
     :param principal: The name of the kerberos principal used for keytab
@@ -63,8 +68,10 @@ class SparkSubmitOperator(BaseOperator):
                  files=None,
                  py_files=None,
                  jars=None,
+                 java_class=None,
                  executor_cores=None,
                  executor_memory=None,
+                 driver_memory=None,
                  keytab=None,
                  principal=None,
                  name='airflow-spark',
@@ -78,8 +85,10 @@ class SparkSubmitOperator(BaseOperator):
         self._files = files
         self._py_files = py_files
         self._jars = jars
+        self._java_class = java_class
         self._executor_cores = executor_cores
         self._executor_memory = executor_memory
+        self._driver_memory = driver_memory
         self._keytab = keytab
         self._principal = principal
         self._name = name
@@ -98,8 +107,10 @@ class SparkSubmitOperator(BaseOperator):
             files=self._files,
             py_files=self._py_files,
             jars=self._jars,
+            java_class=self._java_class,
             executor_cores=self._executor_cores,
             executor_memory=self._executor_memory,
+            driver_memory=self._driver_memory,
             keytab=self._keytab,
             principal=self._principal,
             name=self._name,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/tests/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py
index b18925a..8f514c2 100644
--- a/tests/contrib/hooks/spark_submit_hook.py
+++ b/tests/contrib/hooks/spark_submit_hook.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 import unittest
 
 from airflow import configuration, models
@@ -37,7 +37,9 @@ class TestSparkSubmitHook(unittest.TestCase):
         'principal': 'user/[email protected]',
         'name': 'spark-job',
         'num_executors': 10,
-        'verbose': True
+        'verbose': True,
+        'driver_memory': '3g',
+        'java_class': 'com.foo.bar.AppMain'
     }
 
     def setUp(self):
@@ -45,7 +47,7 @@
         db.merge_conn(
             models.Connection(
                 conn_id='spark_yarn_cluster', conn_type='spark',
-                host='yarn://yarn-mater', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+                host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
         )
         db.merge_conn(
             models.Connection(
@@ -53,6 +55,19 @@
                 host='mesos://host', port=5050)
         )
 
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_home_set', conn_type='spark',
+                host='yarn://yarn-master',
+                extra='{"spark-home": "/opt/myspark"}')
+        )
+
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_home_not_set', conn_type='spark',
+                host='yarn://yarn-master')
+        )
+
     def test_build_command(self):
         hook = SparkSubmitHook(**self._config)
 
@@ -72,6 +87,8 @@
         assert "--principal {}".format(self._config['principal']) in cmd
         assert "--name {}".format(self._config['name']) in cmd
         assert "--num-executors {}".format(self._config['num_executors']) in cmd
+        assert "--class {}".format(self._config['java_class']) in cmd
+        assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
 
         # Check if all config settings are there
         for k in self._config['conf']:
@@ -92,14 +109,14 @@
         # Default to the standard yarn connection because conn_id does not exists
         hook = SparkSubmitHook(conn_id='')
-        self.assertEqual(hook._resolve_connection(), ('yarn', None, None))
+        self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
         assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
 
         # Default to the standard yarn connection
         hook = SparkSubmitHook(conn_id='spark_default')
         self.assertEqual(
             hook._resolve_connection(),
-            ('yarn', 'root.default', None)
+            ('yarn', 'root.default', None, None)
         )
 
         cmd = ' '.join(hook._build_command(self._spark_job_file))
         assert "--master yarn" in cmd
@@ -109,7 +126,7 @@
         hook = SparkSubmitHook(conn_id='spark_default_mesos')
         self.assertEqual(
             hook._resolve_connection(),
-            ('mesos://host:5050', None, None)
+            ('mesos://host:5050', None, None, None)
         )
 
         cmd = ' '.join(hook._build_command(self._spark_job_file))
@@ -119,7 +136,7 @@
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
         self.assertEqual(
             hook._resolve_connection(),
-            ('yarn://yarn-master', 'root.etl', 'cluster')
+            ('yarn://yarn-master', 'root.etl', 'cluster', None)
         )
 
         cmd = ' '.join(hook._build_command(self._spark_job_file))
@@ -127,6 +144,26 @@
         assert "--master yarn://yarn-master" in cmd
         assert "--queue root.etl" in cmd
         assert "--deploy-mode cluster" in cmd
+        # Set the spark home
+        hook = SparkSubmitHook(conn_id='spark_home_set')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn://yarn-master', None, None, '/opt/myspark')
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        assert cmd.startswith('/opt/myspark/bin/spark-submit')
+
+        # Spark home not set
+        hook = SparkSubmitHook(conn_id='spark_home_not_set')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn://yarn-master', None, None, None)
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        assert cmd.startswith('spark-submit')
+
     def test_process_log(self):
         # Must select yarn connection
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/0ade066f/tests/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/spark_submit_operator.py b/tests/contrib/operators/spark_submit_operator.py
index c080f76..4e2afb2 100644
--- a/tests/contrib/operators/spark_submit_operator.py
+++ b/tests/contrib/operators/spark_submit_operator.py
@@ -37,7 +37,9 @@ class TestSparkSubmitOperator(unittest.TestCase):
         'name': 'spark-job',
         'num_executors': 10,
         'verbose': True,
-        'application': 'test_application.py'
+        'application': 'test_application.py',
+        'driver_memory': '3g',
+        'java_class': 'com.foo.bar.AppMain'
     }
 
     def setUp(self):
@@ -69,6 +71,10 @@ class TestSparkSubmitOperator(unittest.TestCase):
         self.assertEqual(self._config['name'], operator._name)
         self.assertEqual(self._config['num_executors'], operator._num_executors)
         self.assertEqual(self._config['verbose'], operator._verbose)
+        self.assertEqual(self._config['java_class'], operator._java_class)
+        self.assertEqual(self._config['driver_memory'], operator._driver_memory)
+
+
 if __name__ == '__main__':
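
As a usage sketch, the new operator arguments map directly onto spark-submit flags;
the DAG id, connection id, class name, and application path below are hypothetical,
not taken from this commit:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator

    dag = DAG('spark_submit_example', start_date=datetime(2017, 4, 1),
              schedule_interval=None)

    submit_app = SparkSubmitOperator(
        task_id='submit_app',
        application='/path/to/app.jar',   # hypothetical artifact
        java_class='com.example.Main',    # passed to spark-submit as --class
        driver_memory='2g',               # passed to spark-submit as --driver-memory
        conn_id='spark_default',          # its extra may also set {"spark-home": "..."}
        dag=dag)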
