This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0b850a97e8 Yandex dataproc deduce default service account (#35059)
0b850a97e8 is described below

commit 0b850a97e8972f9c2cef9cb5ab520b6018033cdb
Author: Peter Reznikov <[email protected]>
AuthorDate: Fri Nov 3 18:52:28 2023 +0300

    Yandex dataproc deduce default service account (#35059)
    
    
    
    ---------
    
    Co-authored-by: Petr Reznikov <[email protected]>
---
 airflow/providers/yandex/hooks/yandex.py           | 55 +++++++++++++++-------
 .../yandex/operators/yandexcloud_dataproc.py       |  2 +-
 airflow/providers/yandex/provider.yaml             | 12 +++++
 tests/providers/yandex/hooks/test_yandex.py        | 12 +++++
 4 files changed, 63 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py
index b54de6f319..e91b287bfa 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -88,14 +88,21 @@ class YandexCloudBaseHook(BaseHook):
     @classmethod
     def provider_user_agent(cls) -> str | None:
         """Construct User-Agent from Airflow core & provider package versions."""
-        import airflow
+        from airflow import __version__ as airflow_version
+        from airflow.configuration import conf
         from airflow.providers_manager import ProvidersManager
 
         try:
             manager = ProvidersManager()
             provider_name = manager.hooks[cls.conn_type].package_name  # type: ignore[union-attr]
             provider = manager.providers[provider_name]
-            return f"apache-airflow/{airflow.__version__} {provider_name}/{provider.version}"
+            return " ".join(
+                (
+                    conf.get("yandex", "sdk_user_agent_prefix", fallback=""),
+                    f"apache-airflow/{airflow_version}",
+                    f"{provider_name}/{provider.version}",
+                )
+            ).strip()
         except KeyError:
             warnings.warn(f"Hook '{cls.hook_name}' info is not initialized in airflow.ProviderManager")
             return None
@@ -115,6 +122,7 @@ class YandexCloudBaseHook(BaseHook):
         yandex_conn_id: str | None = None,
         default_folder_id: str | None = None,
         default_public_ssh_key: str | None = None,
+        default_service_account_id: str | None = None,
     ) -> None:
         super().__init__()
         if connection_id:
@@ -129,31 +137,44 @@ class YandexCloudBaseHook(BaseHook):
         credentials = self._get_credentials()
         sdk_config = self._get_endpoint()
         self.sdk = yandexcloud.SDK(user_agent=self.provider_user_agent(), **sdk_config, **credentials)
-        self.default_folder_id = default_folder_id or self._get_field("folder_id", False)
-        self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key", False)
+        self.default_folder_id = default_folder_id or self._get_field("folder_id")
+        self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key")
+        self.default_service_account_id = default_service_account_id or self._get_service_account_id()
         self.client = self.sdk.client
 
-    def _get_credentials(self) -> dict[str, Any]:
-        service_account_json_path = self._get_field("service_account_json_path", False)
-        service_account_json = self._get_field("service_account_json", False)
-        oauth_token = self._get_field("oauth", False)
-        if not (service_account_json or oauth_token or service_account_json_path):
-            raise AirflowException(
-                "No credentials are found in connection. Specify either service account "
-                "authentication JSON or user OAuth token in Yandex.Cloud connection"
-            )
+    def _get_service_account_key(self) -> dict[str, str] | None:
+        service_account_json = self._get_field("service_account_json")
+        service_account_json_path = self._get_field("service_account_json_path")
         if service_account_json_path:
             with open(service_account_json_path) as infile:
                 service_account_json = infile.read()
         if service_account_json:
-            service_account_key = json.loads(service_account_json)
-            return {"service_account_key": service_account_key}
-        else:
+            return json.loads(service_account_json)
+        return None
+
+    def _get_service_account_id(self) -> str | None:
+        sa_key = self._get_service_account_key()
+        if sa_key:
+            return sa_key.get("service_account_id")
+        return None
+
+    def _get_credentials(self) -> dict[str, Any]:
+        oauth_token = self._get_field("oauth")
+        if oauth_token:
             return {"token": oauth_token}
 
+        service_account_key = self._get_service_account_key()
+        if service_account_key:
+            return {"service_account_key": service_account_key}
+
+        raise AirflowException(
+            "No credentials are found in connection. Specify either service account "
+            "authentication JSON or user OAuth token in Yandex.Cloud connection"
+        )
+
     def _get_endpoint(self) -> dict[str, str]:
         sdk_config = {}
-        endpoint = self._get_field("endpoint", None)
+        endpoint = self._get_field("endpoint")
         if endpoint:
             sdk_config["endpoint"] = endpoint
         return sdk_config
diff --git a/airflow/providers/yandex/operators/yandexcloud_dataproc.py b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
index dfd3a07fe4..de4ea6e9c2 100644
--- a/airflow/providers/yandex/operators/yandexcloud_dataproc.py
+++ b/airflow/providers/yandex/operators/yandexcloud_dataproc.py
@@ -194,7 +194,7 @@ class DataprocCreateClusterOperator(BaseOperator):
             services=self.services,
             s3_bucket=self.s3_bucket,
             zone=self.zone,
-            service_account_id=self.service_account_id,
+            service_account_id=self.service_account_id or self.hook.default_service_account_id,
             masternode_resource_preset=self.masternode_resource_preset,
             masternode_disk_size=self.masternode_disk_size,
             masternode_disk_type=self.masternode_disk_type,
diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml
index fcee1ecc61..5217b6d2b8 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -70,3 +70,15 @@ hooks:
 connection-types:
   - hook-class-name: airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook
     connection-type: yandexcloud
+
+config:
+  yandex:
+    description: This section contains settings for Yandex Cloud integration.
+    options:
+      sdk_user_agent_prefix:
+        description: |
+          Prefix for User-Agent header in Yandex.Cloud SDK requests
+        version_added: 3.6.0
+        type: string
+        example: ~
+        default: ""
diff --git a/tests/providers/yandex/hooks/test_yandex.py b/tests/providers/yandex/hooks/test_yandex.py
index f802ba8d87..23b460dabf 100644
--- a/tests/providers/yandex/hooks/test_yandex.py
+++ b/tests/providers/yandex/hooks/test_yandex.py
@@ -25,6 +25,7 @@ import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
+from tests.test_utils.config import conf_vars
 
 
 class TestYandexHook:
@@ -139,6 +140,17 @@ class TestYandexHook:
 
         assert hook._get_endpoint() == {}
 
+    @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+    @mock.patch("airflow.providers.yandex.hooks.yandex.YandexCloudBaseHook._get_credentials")
+    def test_sdk_user_agent(self, get_credentials_mock, get_connection_mock):
+        get_connection_mock.return_value = mock.Mock(connection_id="yandexcloud_default", extra_dejson="{}")
+        get_credentials_mock.return_value = {"token": 122323}
+        sdk_prefix = "MyAirflow"
+
+        with conf_vars({("yandex", "sdk_user_agent_prefix"): sdk_prefix}):
+            hook = YandexCloudBaseHook()
+            assert hook.sdk._channels._client_user_agent.startswith(sdk_prefix)
+
     @pytest.mark.parametrize(
         "uri",
         [

Reply via email to