mik-laj commented on a change in pull request #16771:
URL: https://github.com/apache/airflow/pull/16771#discussion_r663051792



##########
File path: airflow/providers/amazon/aws/hooks/base_aws.py
##########
@@ -97,57 +99,61 @@ def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session
             **session_kwargs,
         )
 
-    def _impersonate_to_role(
-        self, role_arn: str, session: boto3.session.Session, session_kwargs: Dict[str, Any]
-    ) -> boto3.session.Session:
-        assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
-        assume_role_method = self.extra_config.get('assume_role_method')
+    def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
+        assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
         self.log.info("assume_role_method=%s", assume_role_method)
-        if not assume_role_method or assume_role_method == 'assume_role':
-            sts_client = session.client("sts", config=self.config)
-            sts_response = self._assume_role(
-                sts_client=sts_client, role_arn=role_arn, assume_role_kwargs=assume_role_kwargs
-            )
-        elif assume_role_method == 'assume_role_with_saml':
-            sts_client = session.client("sts", config=self.config)
-            sts_response = self._assume_role_with_saml(
-                sts_client=sts_client, role_arn=role_arn, assume_role_kwargs=assume_role_kwargs
-            )
-        elif assume_role_method == 'assume_role_with_web_identity':
-            botocore_session = self._assume_role_with_web_identity(
-                role_arn=role_arn,
-                assume_role_kwargs=assume_role_kwargs,
-                base_session=session._session,
-            )
-            return boto3.session.Session(
-                region_name=session.region_name,
-                botocore_session=botocore_session,
-                **session_kwargs,
-            )
-        else:
+        supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
+        if assume_role_method not in supported_methods:
             raise NotImplementedError(
                 f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra.'
-                'Currently "assume_role" or "assume_role_with_saml" are supported.'
+                f'Currently {supported_methods} are supported.'
                 '(Exclude this setting will default to "assume_role").'
             )
-        # Use credentials retrieved from STS
-        credentials = sts_response["Credentials"]
-        aws_access_key_id = credentials["AccessKeyId"]
-        aws_secret_access_key = credentials["SecretAccessKey"]
-        aws_session_token = credentials["SessionToken"]
-        self.log.info(
-            "Creating session with aws_access_key_id=%s region_name=%s",
-            aws_access_key_id,
-            session.region_name,
-        )
-
-        return boto3.session.Session(
-            aws_access_key_id=aws_access_key_id,
-            aws_secret_access_key=aws_secret_access_key,
-            region_name=session.region_name,
-            aws_session_token=aws_session_token,
-            **session_kwargs,
-        )
+        if assume_role_method == 'assume_role_with_web_identity':
+            # Deferred credentials have no initial credentials
+            credential_fetcher = self._get_web_identity_credential_fetcher()
+            credentials = DeferredRefreshableCredentials(
+                method='assume-role-with-web-identity',
+                refresh_using=credential_fetcher.fetch_credentials,
+                time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
+            )
+        else:
+            # Refreshable credentials do have initial credentials
+            credentials = RefreshableCredentials.create_from_metadata(
+                metadata=self._refresh_credentials(),
+                refresh_using=self._refresh_credentials,
+                method="sts-assume-role",
+            )
+        session = botocore.session.get_session()
+        session._credentials = credentials  # pylint: disable=protected-access
+        region_name = self.basic_session.region_name
+        session.set_config_variable("region", region_name)
+        return boto3.session.Session(botocore_session=session, **session_kwargs)
+
+    def _refresh_credentials(self) -> boto3.session.Session:

Review comment:
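       I have a feeling this function returns a dictionary, not a `boto3.session.Session`: its result is passed as `metadata=` to `RefreshableCredentials.create_from_metadata`, which expects a credentials metadata dict (access key, secret key, token, expiry time), so the return annotation should probably be `Dict[str, Any]`. Am I right?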




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to