This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 645d52f129 Add `use_krb5ccache` option to `SparkSubmitHook` (#34386)
645d52f129 is described below
commit 645d52f1298c49b2111d058971e1a9f159f1e257
Author: zeotuan <[email protected]>
AuthorDate: Sat Oct 21 20:19:29 2023 +1100
Add `use_krb5ccache` option to `SparkSubmitHook` (#34386)
---
.../providers/apache/spark/hooks/spark_submit.py | 33 +++++++++++++++++++++-
.../apache/spark/hooks/test_spark_submit.py | 27 +++++++++++++++++-
2 files changed, 58 insertions(+), 2 deletions(-)
diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index 75c13c8099..d519eb3e6e 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -78,6 +78,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param verbose: Whether to pass the verbose flag to spark-submit process
for debugging
:param spark_binary: The command to use for spark submit.
Some distros may use spark2-submit or spark3-submit.
+ :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
+ on keytab for Kerberos login
"""
conn_name_attr = "conn_id"
@@ -120,6 +122,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
env_vars: dict[str, Any] | None = None,
verbose: bool = False,
spark_binary: str | None = None,
+ *,
+ use_krb5ccache: bool = False,
) -> None:
super().__init__()
self._conf = conf or {}
@@ -138,7 +142,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self._executor_memory = executor_memory
self._driver_memory = driver_memory
self._keytab = keytab
- self._principal = principal
+ self._principal = self._resolve_kerberos_principal(principal) if use_krb5ccache else principal
+ self._use_krb5ccache = use_krb5ccache
self._proxy_user = proxy_user
self._name = name
self._num_executors = num_executors
@@ -317,6 +322,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
connection_cmd += ["--keytab", self._keytab]
if self._principal:
connection_cmd += ["--principal", self._principal]
+ if self._use_krb5ccache:
+ if not os.getenv("KRB5CCNAME"):
+ raise AirflowException(
+ "KRB5CCNAME environment variable required to use ticket
ccache is missing."
+ )
+ connection_cmd += ["--conf",
"spark.kerberos.renewal.credentials=ccache"]
if self._proxy_user:
connection_cmd += ["--proxy-user", self._proxy_user]
if self._name:
@@ -383,6 +394,26 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
return connection_cmd
+ def _resolve_kerberos_principal(self, principal: str | None) -> str:
+ """Resolve kerberos principal if airflow > 2.8.
+
+ TODO: delete when min airflow version >= 2.8 and import directly from airflow.security.kerberos
+ """
+ from packaging.version import Version
+
+ from airflow.version import version
+
+ if Version(version) < Version("2.8"):
+ from airflow.utils.net import get_hostname
+
+ return principal or airflow_conf.get_mandatory_value("kerberos",
"principal").replace(
+ "_HOST", get_hostname()
+ )
+ else:
+ from airflow.security.kerberos import get_kerberos_principle
+
+ return get_kerberos_principle(principal)
+
def submit(self, application: str = "", **kwargs: Any) -> None:
"""
Remote Popen to execute the spark-submit job.
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py
index d3feccb081..9bf828e94d 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -61,6 +61,7 @@ class TestSparkSubmitHook:
"args should keep embedded spaces",
"baz",
],
+ "use_krb5ccache": True,
}
@staticmethod
@@ -141,7 +142,10 @@ class TestSparkSubmitHook:
)
)
- def test_build_spark_submit_command(self):
+ @patch(
+ "airflow.providers.apache.spark.hooks.spark_submit.os.getenv", return_value="/tmp/airflow_krb5_ccache"
+ )
+ def test_build_spark_submit_command(self, mock_get_env):
# Given
hook = SparkSubmitHook(**self._config)
@@ -183,6 +187,8 @@ class TestSparkSubmitHook:
"privileged_user.keytab",
"--principal",
"user/[email protected]",
+ "--conf",
+ "spark.kerberos.renewal.credentials=ccache",
"--proxy-user",
"sample_user",
"--name",
@@ -200,6 +206,25 @@ class TestSparkSubmitHook:
"baz",
]
assert expected_build_cmd == cmd
+ mock_get_env.assert_called_with("KRB5CCNAME")
+
+ @patch("airflow.configuration.conf.get_mandatory_value")
+ def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_principal(self, mock_get_madantory_value):
+ mock_principle = "airflow"
+ mock_get_madantory_value.return_value = mock_principle
+ hook = SparkSubmitHook(conn_id="spark_yarn_cluster", principal=None,
use_krb5ccache=True)
+ mock_get_madantory_value.assert_called_with("kerberos", "principal")
+ assert hook._principal == mock_principle
+
+ def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_KRB5CCNAME_env(self):
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_cluster", principal="user/[email protected]",
use_krb5ccache=True
+ )
+ with pytest.raises(
+ AirflowException,
+ match="KRB5CCNAME environment variable required to use ticket
ccache is missing.",
+ ):
+ hook._build_spark_submit_command(self._spark_job_file)
def test_build_track_driver_status_command(self):
# note this function is only relevant for spark setup matching below condition