This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 57f203251b FIX AWS deferrable operators by using AioCredentials when 
using `assume_role` (#32733)
57f203251b is described below

commit 57f203251b223550d6e7bb717910109af9aeed29
Author: Rishi Kulkarni <[email protected]>
AuthorDate: Sat Jul 22 13:28:30 2023 -0400

    FIX AWS deferrable operators by using AioCredentials when using 
`assume_role` (#32733)
    
    * FIX: deferrable operators now use AioCredentials
    
    Fixes airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory feeds 
synchronous credentials to aiobotocore when using `assume_role` #32732
    
    * formatting
    
    * use dict unpacking
    
    * TEST: add test that checks that credentials._refresh is a coroutine
    
    * TEST: assert that `credentials.get_frozen_credentials` is a coroutine
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 35 ++++++++++++++++-------
 tests/providers/amazon/aws/hooks/test_base_aws.py |  4 +++
 2 files changed, 28 insertions(+), 11 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py 
b/airflow/providers/amazon/aws/hooks/base_aws.py
index aaec87dd09..94782401ac 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -197,22 +197,35 @@ class BaseSessionFactory(LoggingMixin):
     def _create_session_with_assume_role(
         self, session_kwargs: dict[str, Any], deferrable: bool = False
     ) -> boto3.session.Session:
-
         if self.conn.assume_role_method == "assume_role_with_web_identity":
             # Deferred credentials have no initial credentials
             credential_fetcher = self._get_web_identity_credential_fetcher()
-            credentials = botocore.credentials.DeferredRefreshableCredentials(
-                method="assume-role-with-web-identity",
-                refresh_using=credential_fetcher.fetch_credentials,
-                time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
-            )
+
+            params = {
+                "method": "assume-role-with-web-identity",
+                "refresh_using": credential_fetcher.fetch_credentials,
+                "time_fetcher": lambda: datetime.datetime.now(tz=tzlocal()),
+            }
+
+            if deferrable:
+                from aiobotocore.credentials import 
AioDeferredRefreshableCredentials
+
+                credentials = AioDeferredRefreshableCredentials(**params)
+            else:
+                credentials = 
botocore.credentials.DeferredRefreshableCredentials(**params)
         else:
             # Refreshable credentials do have initial credentials
-            credentials = 
botocore.credentials.RefreshableCredentials.create_from_metadata(
-                metadata=self._refresh_credentials(),
-                refresh_using=self._refresh_credentials,
-                method="sts-assume-role",
-            )
+            params = {
+                "metadata": self._refresh_credentials(),
+                "refresh_using": self._refresh_credentials,
+                "method": "sts-assume-role",
+            }
+            if deferrable:
+                from aiobotocore.credentials import AioRefreshableCredentials
+
+                credentials = 
AioRefreshableCredentials.create_from_metadata(**params)
+            else:
+                credentials = 
botocore.credentials.RefreshableCredentials.create_from_metadata(**params)
 
         if deferrable:
             from aiobotocore.session import get_session as async_get_session
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py 
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 6f84e4ad24..8ba8b6b812 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import inspect
 import json
 import os
 from base64 import b64encode
@@ -332,6 +333,9 @@ class TestSessionFactory:
             # Validate method of botocore credentials provider.
             # It shouldn't be 'explicit' which refers in this case to initial 
credentials.
             credentials = await session.get_credentials()
+
+            assert 
inspect.iscoroutinefunction(credentials.get_frozen_credentials)
+
             assert credentials.method == "sts-assume-role"
 
 

Reply via email to