Repository: incubator-airflow
Updated Branches:
  refs/heads/master b0d0d0a04 -> 36193fc74


[AIRFLOW-2380] Add support for environment variables in Spark submit operator.

Closes #3268 from piffall/master


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/36193fc7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/36193fc7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/36193fc7

Branch: refs/heads/master
Commit: 36193fc7449ca67c807b54ad17a086b35c0c4471
Parents: b0d0d0a
Author: Cristòfol Torrens <tofol.torr...@bluekiri.com>
Authored: Thu Apr 26 14:21:21 2018 -0700
Committer: Arthur Wiedmer <art...@apache.org>
Committed: Thu Apr 26 14:21:21 2018 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py      | 29 +++++++++-
 .../contrib/operators/spark_submit_operator.py  | 10 +++-
 tests/contrib/hooks/test_spark_submit_hook.py   | 59 +++++++++++++++++++-
 3 files changed, 91 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/36193fc7/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index 71c68c0..0185cab 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -80,6 +80,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :type num_executors: int
     :param application_args: Arguments for the application being submitted
     :type application_args: list
+    :param env_vars: Environment variables for spark-submit. Also supported
+                     in yarn and Kubernetes (k8s) mode.
+    :type env_vars: dict
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :type verbose: bool
     """
@@ -103,6 +106,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                  name='default-name',
                  num_executors=None,
                  application_args=None,
+                 env_vars=None,
                  verbose=False):
         self._conf = conf
         self._conn_id = conn_id
@@ -123,6 +127,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._name = name
         self._num_executors = num_executors
         self._application_args = application_args
+        self._env_vars = env_vars
         self._verbose = verbose
         self._submit_sp = None
         self._yarn_application_id = None
@@ -209,6 +214,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         if self._conf:
             for key in self._conf:
                 connection_cmd += ["--conf", "{}={}".format(key, 
str(self._conf[key]))]
+        if self._env_vars and (self._is_kubernetes or self._is_yarn):
+            if self._is_yarn:
+                tmpl = "spark.yarn.appMasterEnv.{}={}"
+            else:
+                tmpl = "spark.kubernetes.driverEnv.{}={}"
+            for key in self._env_vars:
+                connection_cmd += [
+                    "--conf",
+                    tmpl.format(key, str(self._env_vars[key]))]
+        elif self._env_vars and self._connection['deploy_mode'] != "cluster":
+            self._env = self._env_vars  # Do it on Popen of the process
+        elif self._env_vars and self._connection['deploy_mode'] == "cluster":
+            raise AirflowException(
+                "SparkSubmitHook env_vars is not supported in standalone-cluster mode.")
         if self._is_kubernetes:
             connection_cmd += ["--conf", 
"spark.kubernetes.namespace={}".format(
                 self._connection['namespace'])]
@@ -294,6 +313,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         :param kwargs: extra arguments to Popen (see subprocess.Popen)
         """
         spark_submit_cmd = self._build_spark_submit_command(application)
+
+        if hasattr(self, '_env'):
+            env = os.environ.copy()
+            env.update(self._env)
+            kwargs["env"] = env
+
         self._submit_sp = subprocess.Popen(spark_submit_cmd,
                                            stdout=subprocess.PIPE,
                                            stderr=subprocess.STDOUT,
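
For reference, a minimal sketch of how the hook routes env_vars, based on the
diff above (the connection id mirrors the test fixtures and assumes such a
connection is configured; values are illustrative):

    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    # On YARN each entry becomes --conf spark.yarn.appMasterEnv.<KEY>=<VALUE>;
    # on Kubernetes it becomes --conf spark.kubernetes.driverEnv.<KEY>=<VALUE>.
    hook = SparkSubmitHook(conn_id='spark_yarn_cluster', env_vars={'FOO': 'bar'})
    cmd = hook._build_spark_submit_command('/path/to/job.py')
    assert 'spark.yarn.appMasterEnv.FOO=bar' in cmd

    # In client mode against a standalone master the variables are instead
    # merged into a copy of os.environ and handed to the spark-submit
    # subprocess through Popen's env argument.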

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/36193fc7/airflow/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
index 61a5d6a..27bd61b 100644
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -77,6 +77,9 @@ class SparkSubmitOperator(BaseOperator):
     :type num_executors: int
     :param application_args: Arguments for the application being submitted
     :type application_args: list
+    :param env_vars: Environment variables for spark-submit. Also supported
+                     in yarn and Kubernetes (k8s) mode.
+    :type env_vars: dict
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :type verbose: bool
     """
@@ -105,6 +108,7 @@ class SparkSubmitOperator(BaseOperator):
                  name='airflow-spark',
                  num_executors=None,
                  application_args=None,
+                 env_vars=None,
                  verbose=False,
                  *args,
                  **kwargs):
@@ -128,6 +132,7 @@ class SparkSubmitOperator(BaseOperator):
         self._name = name
         self._num_executors = num_executors
         self._application_args = application_args
+        self._env_vars = env_vars
         self._verbose = verbose
         self._hook = None
         self._conn_id = conn_id
@@ -156,6 +161,7 @@ class SparkSubmitOperator(BaseOperator):
             name=self._name,
             num_executors=self._num_executors,
             application_args=self._application_args,
+            env_vars=self._env_vars,
             verbose=self._verbose
         )
         self._hook.submit(self._application)
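
A hedged usage sketch of the new operator parameter (DAG id, task id, and
values are illustrative, not part of the commit; conn_id assumes a configured
Spark connection):

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator

    dag = DAG('spark_env_vars_example', start_date=datetime(2018, 4, 26),
              schedule_interval=None)

    submit_job = SparkSubmitOperator(
        task_id='submit_with_env',
        application='/path/to/job.py',
        conn_id='spark_default',           # hypothetical connection id
        env_vars={'ENVIRONMENT': 'prod'},  # forwarded per master/deploy mode
        dag=dag,
    )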

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/36193fc7/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index c0cfff7..1bdcda5 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -20,7 +20,7 @@
 import six
 import unittest
 
-from airflow import configuration, models
+from airflow import configuration, models, AirflowException
 from airflow.utils import db
 from mock import patch, call
 
@@ -121,6 +121,12 @@ class TestSparkSubmitHook(unittest.TestCase):
                 host='spark://spark-standalone-master:6066',
                 extra='{"spark-home": "/path/to/spark_home", "deploy-mode": 
"cluster"}')
         )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_standalone_cluster_client_mode', conn_type='spark',
+                host='spark://spark-standalone-master:6066',
+                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}')
+        )
 
     def test_build_spark_submit_command(self):
         # Given
@@ -409,6 +415,53 @@ class TestSparkSubmitHook(unittest.TestCase):
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit')
 
+    def test_resolve_spark_submit_env_vars_standalone_client_mode(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_standalone_cluster_client_mode',
+                               env_vars={"bar": "foo"})
+
+        # When
+        hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        self.assertEqual(hook._env, {"bar": "foo"})
+
+    def test_resolve_spark_submit_env_vars_standalone_cluster_mode(self):
+
+        def env_vars_exception_in_standalone_cluster_mode():
+            # Given
+            hook = SparkSubmitHook(conn_id='spark_standalone_cluster',
+                                   env_vars={"bar": "foo"})
+
+            # When
+            hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        self.assertRaises(AirflowException,
+                          env_vars_exception_in_standalone_cluster_mode)
+
+    def test_resolve_spark_submit_env_vars_yarn(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_yarn_cluster',
+                               env_vars={"bar": "foo"})
+
+        # When
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        self.assertEqual(cmd[4], "spark.yarn.appMasterEnv.bar=foo")
+
+    def test_resolve_spark_submit_env_vars_k8s(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_k8s_cluster',
+                               env_vars={"bar": "foo"})
+
+        # When
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        self.assertEqual(cmd[4], "spark.kubernetes.driverEnv.bar=foo")
+
     def test_process_spark_submit_log_yarn(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')

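
As the tests above exercise, env_vars is rejected for standalone cluster-mode
connections. A short sketch of that failure mode (connection id mirrors the
test fixture):

    from airflow import AirflowException
    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    hook = SparkSubmitHook(conn_id='spark_standalone_cluster',
                           env_vars={'bar': 'foo'})
    try:
        hook._build_spark_submit_command('/path/to/job.py')
    except AirflowException:
        pass  # env_vars cannot be forwarded in standalone-cluster mode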