This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 30817a5c6d support iam token from metadata, simplify code (#38411)
30817a5c6d is described below

commit 30817a5c6df1be3ec080ff1c542899092679768f
Author: uzhastik <uz...@ydb.tech>
AuthorDate: Sat Mar 23 00:30:33 2024 +0300

    support iam token from metadata, simplify code (#38411)
---
 airflow/providers/yandex/hooks/yq.py    | 54 +++++++++++----------------------
 airflow/providers/yandex/provider.yaml  |  3 --
 generated/provider_dependencies.json    |  2 --
 pyproject.toml                          |  2 --
 tests/providers/yandex/hooks/test_yq.py | 49 +++++++++++++++++++++++++-----
 5 files changed, 59 insertions(+), 51 deletions(-)

diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py
index 963709d89b..37f7550df6 100644
--- a/airflow/providers/yandex/hooks/yq.py
+++ b/airflow/providers/yandex/hooks/yq.py
@@ -16,16 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-import time
 from datetime import timedelta
 from typing import Any
 
-import jwt
-import requests
-from urllib3.util.retry import Retry
+import yandexcloud
+import yandexcloud._auth_fabric as auth_fabric
+from yandex.cloud.iam.v1.iam_token_service_pb2_grpc import IamTokenServiceStub
 from yandex_query_client import YQHttpClient, YQHttpClientConfig
 
-from airflow.exceptions import AirflowException
 from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
 from airflow.providers.yandex.utils.user_agent import provider_user_agent
 
@@ -98,35 +96,17 @@ class YQHook(YandexCloudBaseHook):
         return self.client.compose_query_web_link(query_id)
 
     def _get_iam_token(self) -> str:
-        if "token" in self.credentials:
-            return self.credentials["token"]
-        if "service_account_key" in self.credentials:
-            return YQHook._resolve_service_account_key(self.credentials["service_account_key"])
-        raise AirflowException(f"Unknown credentials type, available keys {self.credentials.keys()}")
-
-    @staticmethod
-    def _resolve_service_account_key(sa_info: dict) -> str:
-        with YQHook._create_session() as session:
-            api = "https://iam.api.cloud.yandex.net/iam/v1/tokens"
-            now = int(time.time())
-            payload = {"aud": api, "iss": sa_info["service_account_id"], "iat": now, "exp": now + 360}
-
-            encoded_token = jwt.encode(
-                payload, sa_info["private_key"], algorithm="PS256", headers={"kid": sa_info["id"]}
-            )
-
-            data = {"jwt": encoded_token}
-            iam_response = session.post(api, json=data)
-            iam_response.raise_for_status()
-
-            return iam_response.json()["iamToken"]
-
-    @staticmethod
-    def _create_session() -> requests.Session:
-        session = requests.Session()
-        session.verify = False
-        retry = Retry(backoff_factor=0.3, total=10)
-        session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retry))
-        session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retry))
-
-        return session
+        iam_token = self.credentials.get("token")
+        if iam_token is not None:
+            return iam_token
+
+        service_account_key = self.credentials.get("service_account_key")
+        # if service_account_key is None metadata server will be used
+        token_requester = auth_fabric.get_auth_token_requester(service_account_key=service_account_key)
+
+        if service_account_key is None:
+            return token_requester.get_token()
+
+        sdk = yandexcloud.SDK()
+        client = sdk.client(IamTokenServiceStub)
+        return client.Create(token_requester.get_token_request()).iam_token
diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml
index 0135ac3fb4..df700127c4 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -50,9 +50,6 @@ dependencies:
   - apache-airflow>=2.6.0
   - yandexcloud>=0.228.0
   - yandex-query-client>=0.1.2
-  - python-dateutil>=2.8.0
-  # Requests 3 if it will be released, will be heavily breaking.
-  - requests>=2.27.0,<3
 
 integrations:
   - integration-name: Yandex.Cloud
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index fe9848069d..4f110f9918 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1180,8 +1180,6 @@
   "yandex": {
     "deps": [
       "apache-airflow>=2.6.0",
-      "python-dateutil>=2.8.0",
-      "requests>=2.27.0,<3",
       "yandex-query-client>=0.1.2",
       "yandexcloud>=0.228.0"
     ],
diff --git a/pyproject.toml b/pyproject.toml
index 89d7496c4a..ace1b0800a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -975,8 +975,6 @@ weaviate = [ # source: airflow/providers/weaviate/provider.yaml
   "weaviate-client>=3.24.2",
 ]
 yandex = [ # source: airflow/providers/yandex/provider.yaml
-  "python-dateutil>=2.8.0",
-  "requests>=2.27.0,<3",
   "yandex-query-client>=0.1.2",
   "yandexcloud>=0.228.0",
 ]
diff --git a/tests/providers/yandex/hooks/test_yq.py b/tests/providers/yandex/hooks/test_yq.py
index 3b3db91dd1..c378c65347 100644
--- a/tests/providers/yandex/hooks/test_yq.py
+++ b/tests/providers/yandex/hooks/test_yq.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from datetime import timedelta
 from unittest import mock
 
@@ -26,6 +27,7 @@ from airflow.models import Connection
 from airflow.providers.yandex.hooks.yq import YQHook
 
 OAUTH_TOKEN = "my_oauth_token"
+IAM_TOKEN = "my_iam_token"
 SERVICE_ACCOUNT_AUTH_KEY_JSON = """{"id":"my_id", "service_account_id":"my_sa1", "private_key":"my_pk"}"""
 
 
@@ -34,6 +36,18 @@ class DummySDK:
         self.client = None
 
 
+class DummyTokenRequester:
+    def get_token(self) -> str:
+        return IAM_TOKEN
+
+    def get_token_request(self) -> str:
+        return "my_dummy_request"
+
+
+class DummyCreateTokenResponse:
+    iam_token = "zzz"
+
+
 class TestYandexCloudYqHook:
     def _init_hook(self):
         with mock.patch("airflow.hooks.base.BaseHook.get_connection") as mock_get_connection:
@@ -68,18 +82,33 @@ class TestYandexCloudYqHook:
             m.assert_called_once_with("query1")
 
     @responses.activate()
-    @mock.patch("yandexcloud.SDK")
-    @mock.patch("jwt.encode")
-    def test_select_results(self, mock_jwt, mock_sdk):
+    @mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester())
+    def test_metadata_token_usage(self, mock_get_auth_token_requester):
         responses.post(
-            "https://iam.api.cloud.yandex.net/iam/v1/tokens",
-            json={"iamToken": "super_token"},
+            "https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries",
+            match=[
+                matchers.header_matcher(
+                    {"Content-Type": "application/json", "Authorization": f"Bearer {IAM_TOKEN}"}
+                ),
+                matchers.query_param_matcher({"project": "my_folder_id"}),
+            ],
+            json={"id": "query1"},
             status=200,
         )
 
-        mock_jwt.return_value = "zzzz"
-        mock_sdk.return_value = DummySDK()
+        self.connection = Connection(extra={})
+        self._init_hook()
+        query_id = self.hook.create_query(query_text="select 777", name="my query")
+        assert query_id == "query1"
 
+    @mock.patch(
+        "yandex.cloud.iam.v1.iam_token_service_pb2_grpc.IamTokenServiceStub.Create",
+        create=True,
+        new_callable=mock.PropertyMock,
+    )
+    @mock.patch("yandexcloud._auth_fabric.__validate_service_account_key")
+    @mock.patch("yandexcloud._auth_fabric.get_auth_token_requester", return_value=DummyTokenRequester())
+    def test_select_results(self, mock_get_auth_token_requester, mock_validate, mock_create_token):
         with mock.patch.multiple(
             "yandex_query_client.YQHttpClient",
             create_query=mock.DEFAULT,
@@ -90,6 +119,12 @@ class TestYandexCloudYqHook:
             stop_query=mock.DEFAULT,
         ) as mocks:
             self._init_hook()
+            mock_validate.assert_called()
+            mock_create_token.assert_called()
+            mock_get_auth_token_requester.assert_called_once_with(
+                service_account_key=json.loads(SERVICE_ACCOUNT_AUTH_KEY_JSON)
+            )
+
             mocks["create_query"].return_value = "query1"
             mocks["wait_query_to_succeed"].return_value = 2
             mocks["get_query_all_result_sets"].return_value = {"x": 765}

Reply via email to