This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0b850a97e8 Yandex dataproc deduce default service account (#35059)
0b850a97e8 is described below
commit 0b850a97e8972f9c2cef9cb5ab520b6018033cdb
Author: Peter Reznikov <[email protected]>
AuthorDate: Fri Nov 3 18:52:28 2023 +0300
Yandex dataproc deduce default service account (#35059)
---------
Co-authored-by: Petr Reznikov <[email protected]>
---
airflow/providers/yandex/hooks/yandex.py | 55 +++++++++++++++-------
.../yandex/operators/yandexcloud_dataproc.py | 2 +-
airflow/providers/yandex/provider.yaml | 12 +++++
tests/providers/yandex/hooks/test_yandex.py | 12 +++++
4 files changed, 63 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/yandex/hooks/yandex.py
b/airflow/providers/yandex/hooks/yandex.py
index b54de6f319..e91b287bfa 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -88,14 +88,21 @@ class YandexCloudBaseHook(BaseHook):
@classmethod
def provider_user_agent(cls) -> str | None:
"""Construct User-Agent from Airflow core & provider package
versions."""
- import airflow
+ from airflow import __version__ as airflow_version
+ from airflow.configuration import conf
from airflow.providers_manager import ProvidersManager
try:
manager = ProvidersManager()
provider_name = manager.hooks[cls.conn_type].package_name # type:
ignore[union-attr]
provider = manager.providers[provider_name]
- return f"apache-airflow/{airflow.__version__}
{provider_name}/{provider.version}"
+ return " ".join(
+ (
+ conf.get("yandex", "sdk_user_agent_prefix", fallback=""),
+ f"apache-airflow/{airflow_version}",
+ f"{provider_name}/{provider.version}",
+ )
+ ).strip()
except KeyError:
warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in
airflow.ProviderManager")
return None
@@ -115,6 +122,7 @@ class YandexCloudBaseHook(BaseHook):
yandex_conn_id: str | None = None,
default_folder_id: str | None = None,
default_public_ssh_key: str | None = None,
+ default_service_account_id: str | None = None,
) -> None:
super().__init__()
if connection_id:
@@ -129,31 +137,44 @@ class YandexCloudBaseHook(BaseHook):
credentials = self._get_credentials()
sdk_config = self._get_endpoint()
self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(),
**sdk_config, **credentials)
- self.default_folder_id = default_folder_id or
self._get_field("folder_id", False)
- self.default_public_ssh_key = default_public_ssh_key or
self._get_field("public_ssh_key", False)
+ self.default_folder_id = default_folder_id or
self._get_field("folder_id")
+ self.default_public_ssh_key = default_public_ssh_key or
self._get_field("public_ssh_key")
+ self.default_service_account_id = default_service_account_id or
self._get_service_account_id()
self.client = self.sdk.client
- def _get_credentials(self) -> dict[str, Any]:
- service_account_json_path =
self._get_field("service_account_json_path", False)
- service_account_json = self._get_field("service_account_json", False)
- oauth_token = self._get_field("oauth", False)
- if not (service_account_json or oauth_token or
service_account_json_path):
- raise AirflowException(
- "No credentials are found in connection. Specify either
service account "
- "authentication JSON or user OAuth token in Yandex.Cloud
connection"
- )
+ def _get_service_account_key(self) -> dict[str, str] | None:
+ service_account_json = self._get_field("service_account_json")
+ service_account_json_path =
self._get_field("service_account_json_path")
if service_account_json_path:
with open(service_account_json_path) as infile:
service_account_json = infile.read()
if service_account_json:
- service_account_key = json.loads(service_account_json)
- return {"service_account_key": service_account_key}
- else:
+ return json.loads(service_account_json)
+ return None
+
+ def _get_service_account_id(self) -> str | None:
+ sa_key = self._get_service_account_key()
+ if sa_key:
+ return sa_key.get("service_account_id")
+ return None
+
+ def _get_credentials(self) -> dict[str, Any]:
+ oauth_token = self._get_field("oauth")
+ if oauth_token:
return {"token": oauth_token}
+ service_account_key = self._get_service_account_key()
+ if service_account_key:
+ return {"service_account_key": service_account_key}
+
+ raise AirflowException(
+ "No credentials are found in connection. Specify either service
account "
+ "authentication JSON or user OAuth token in Yandex.Cloud
connection"
+ )
+
def _get_endpoint(self) -> dict[str, str]:
sdk_config = {}
- endpoint = self._get_field("endpoint", None)
+ endpoint = self._get_field("endpoint")
if endpoint:
sdk_config["endpoint"] = endpoint
return sdk_config
diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py
b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
index dfd3a07fe4..de4ea6e9c2 100644
--- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py
+++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
@@ -194,7 +194,7 @@ class DataprocCreateClusterOperator(BaseOperator):
services=self.services,
s3_bucket=self.s3_bucket,
zone=self.zone,
- service_account_id=self.service_account_id,
+ service_account_id=self.service_account_id or
self.hook.default_service_account_id,
masternode_resource_preset=self.masternode_resource_preset,
masternode_disk_size=self.masternode_disk_size,
masternode_disk_type=self.masternode_disk_type,
diff --git a/airflow/providers/yandex/provider.yaml
b/airflow/providers/yandex/provider.yaml
index fcee1ecc61..5217b6d2b8 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -70,3 +70,15 @@ hooks:
connection-types:
- hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook
connection-type: yandexcloud
+
+config:
+ yandex:
+ description: This section contains settings for Yandex Cloud integration.
+ options:
+ sdk_user_agent_prefix:
+ description: |
+ Prefix for User-Agent header in Yandex.Cloud SDK requests
+ version_added: 3.6.0
+ type: string
+ example: ~
+ default: ""
diff --git a/tests/providers/yandex/hooks/test_yandex.py
b/tests/providers/yandex/hooks/test_yandex.py
index f802ba8d87..23b460dabf 100644
--- a/tests/providers/yandex/hooks/test_yandex.py
+++ b/tests/providers/yandex/hooks/test_yandex.py
@@ -25,6 +25,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
+from tests.test_utils.config import conf_vars
class TestYandexHook:
@@ -139,6 +140,17 @@ class TestYandexHook:
assert hook._get_endpoint() == {}
+ @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+
@mock.patch("airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials")
+ def test_sdk_user_agent(self, get_credentials_mock, get_connection_mock):
+ get_connection_mock.return_value =
mock.Mock(connection_id="yandexcloud_default", extra_dejson="{}")
+ get_credentials_mock.return_value = {"token": 122323}
+ sdk_prefix = "MyAirflow"
+
+ with conf_vars({("yandex", "sdk_user_agent_prefix"): sdk_prefix}):
+ hook = YandexCloudBaseHook()
+ assert hook.sdk._channels._client_user_agent.startswith(sdk_prefix)
+
@pytest.mark.parametrize(
"uri",
[