This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a6315c2f4e Amazon Provider Package user agent (#27823)
a6315c2f4e is described below

commit a6315c2f4ed68c822d0109f9609c1518e0bde94e
Author: D. Ferruzzi <[email protected]>
AuthorDate: Thu Dec 8 10:51:20 2022 -0800

    Amazon Provider Package user agent (#27823)
---
 airflow/providers/amazon/aws/hooks/base_aws.py    | 134 +++++++++++++++++++---
 tests/providers/amazon/aws/hooks/test_base_aws.py |  64 ++++++++++-
 2 files changed, 177 insertions(+), 21 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index ec06eb22e5..6ff5f21467 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -25,9 +25,13 @@ This module contains Base AWS Hook.
 from __future__ import annotations
 
 import datetime
+import inspect
 import json
 import logging
+import os
+import uuid
 import warnings
+from copy import deepcopy
 from functools import wraps
 from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
 
@@ -47,6 +51,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.log.secrets_masker import mask_secret
@@ -409,6 +414,92 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         self._config = config
         self._verify = verify
 
+    @classmethod
+    def _get_provider_version(cls) -> str:
+        """Checks the Providers Manager for the package version."""
+        try:
+            manager = ProvidersManager()
+            hook = manager.hooks[cls.conn_type]
+            if not hook:
+                # This gets caught immediately, but without it MyPy complains
+                # Item "None" of "Optional[HookInfo]" has no attribute "package_name"
+                # on the following line and static checks fail.
+                raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
+            provider = manager.providers[hook.package_name]
+            return provider.version
+        except Exception:
+            # Under no condition should an error here ever cause an issue for the user.
+            return "Unknown"
+
+    @staticmethod
+    def _find_class_name(target_function_name: str) -> str:
+        """
+        Given a frame off the stack, return the name of the class which made the call.
+        Note: This method may raise a ValueError or an IndexError, but the calling
+        method is catching and handling those.
+        """
+        stack = inspect.stack()
+        # Find the index of the most recent frame which called the provided function name.
+        target_frame_index = [frame.function for frame in stack].index(target_function_name)
+        # Pull that frame off the stack.
+        target_frame = stack[target_frame_index][0]
+        # Get the local variables for that frame.
+        frame_variables = target_frame.f_locals["self"]
+        # Get the class object for that frame.
+        frame_class_object = frame_variables.__class__
+        # Return the name of the class object.
+        return frame_class_object.__name__
+
+    def _get_caller(self, target_function_name: str = "execute") -> str:
+        """Given a function name, walk the stack and return the name of the class which called it last."""
+        try:
+            caller = self._find_class_name(target_function_name)
+            if caller == "BaseSensorOperator":
+                # If the result is a BaseSensorOperator, then look for whatever last called "poke".
+                return self._get_caller("poke")
+            return caller
+        except Exception:
+            # Under no condition should an error here ever cause an issue for the user.
+            return "Unknown"
+
+    @staticmethod
+    def _generate_dag_key() -> str:
+        """
+        The Object Identifier (OID) namespace is used to salt the dag_id value.
+        That salted value is used to generate a SHA-1 hash which, by definition,
+        can not (reasonably) be reversed.  No personal data can be inferred or
+        extracted from the resulting UUID.
+        """
+        try:
+            dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
+            return str(uuid.uuid5(uuid.NAMESPACE_OID, dag_id))
+        except Exception:
+            # Under no condition should an error here ever cause an issue for the user.
+            return "00000000-0000-0000-0000-000000000000"
+
+    @staticmethod
+    def _get_airflow_version() -> str:
+        """Fetch and return the current Airflow version."""
+        try:
+            # This can be a circular import under specific configurations.
+            # Importing locally to either avoid or catch it if it does happen.
+            from airflow import __version__ as airflow_version
+
+            return airflow_version
+        except Exception:
+            # Under no condition should an error here ever cause an issue for the user.
+            return "Unknown"
+
+    def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
+        user_agent_extra_values = [
+            f"Airflow/{self._get_airflow_version()}",
+            f"AmPP/{self._get_provider_version()}",
+            f"Caller/{self._get_caller()}",
+            f"DagRunKey/{self._generate_dag_key()}",
+            existing_user_agent_extra or "",
+        ]
+        return " ".join(user_agent_extra_values).strip()
+
     @cached_property
     def conn_config(self) -> AwsConnectionWrapper:
         """Get the Airflow Connection object and wrap it in helper (cached)."""
@@ -436,9 +527,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         return self.conn_config.region_name
 
     @property
-    def config(self) -> Config | None:
+    def config(self) -> Config:
         """Configuration for botocore client read-only property."""
-        return self.conn_config.botocore_config
+        return self.conn_config.botocore_config or botocore.config.Config()
 
     @property
     def verify(self) -> bool | str | None:
@@ -451,6 +542,23 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
             conn=self.conn_config, region_name=region_name, config=self.config
         ).create_session()
 
+    def _get_config(self, config: Config | None = None) -> Config:
+        """
+        No AWS Operators use the config argument to this method.
+        Keep backward compatibility with other users who might use it
+        """
+        if config is None:
+            config = deepcopy(self.config)
+
+        # ignore[union-attr] is required for this block to appease MyPy
+        # because the user_agent_extra field is generated at runtime.
+        user_agent_config = Config(
+            user_agent_extra=self._generate_user_agent_extra_field(
+                existing_user_agent_extra=config.user_agent_extra  # type: ignore[union-attr]
+            )
+        )
+        return config.merge(user_agent_config)  # type: ignore[union-attr]
+
     def get_client_type(
         self,
         region_name: str | None = None,
@@ -458,15 +566,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
     ) -> boto3.client:
         """Get the underlying boto3 client using boto3 session"""
         client_type = self.client_type
-
-        # No AWS Operators use the config argument to this method.
-        # Keep backward compatibility with other users who might use it
-        if config is None:
-            config = self.config
-
         session = self.get_session(region_name=region_name)
         return session.client(
-            client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+            client_type,
+            endpoint_url=self.conn_config.endpoint_url,
+            config=self._get_config(config),
+            verify=self.verify,
         )
 
     def get_resource_type(
@@ -476,15 +581,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
     ) -> boto3.resource:
         """Get the underlying boto3 resource using boto3 session"""
         resource_type = self.resource_type
-
-        # No AWS Operators use the config argument to this method.
-        # Keep backward compatibility with other users who might use it
-        if config is None:
-            config = self.config
-
         session = self.get_session(region_name=region_name)
         return session.resource(
-            resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+            resource_type,
+            endpoint_url=self.conn_config.endpoint_url,
+            config=self._get_config(config),
+            verify=self.verify,
         )
 
     @cached_property
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 3a17b39e10..8db11fbb38 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -22,6 +22,7 @@ import os
 from base64 import b64encode
 from datetime import datetime, timedelta, timezone
 from unittest import mock
+from uuid import UUID
 
 import boto3
 import pytest
@@ -301,6 +302,58 @@ class TestAwsBaseHook:
 
         assert table.item_count == 0
 
+    @pytest.mark.parametrize(
+        "client_meta",
+        [
+            AwsBaseHook(client_type="s3").get_client_type().meta,
+            AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
+        ],
+    )
+    def test_user_agent_extra_update(self, client_meta):
+        """
+        We are only looking for the keys appended by the AwsBaseHook. A user_agent string
+        is a number of key/value pairs such as: `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`.
+        """
+        expected_user_agent_tag_keys = ["Airflow", "AmPP", "Caller", "DagRunKey"]
+
+        result_user_agent_tags = client_meta.config.user_agent.split(" ")
+        result_user_agent_tag_keys = [tag.split("/")[0].lower() for tag in result_user_agent_tags]
+
+        for key in expected_user_agent_tag_keys:
+            assert key.lower() in result_user_agent_tag_keys
+
+    @staticmethod
+    def fetch_tags() -> dict[str:str]:
+        """Helper method which creates an AwsBaseHook and returns the user agent string split into a dict."""
+        user_agent_string = AwsBaseHook(client_type="s3").get_client_type().meta.config.user_agent
+        # Split the list of {Key}/{Value} into a dict
+        return dict(tag.split("/") for tag in user_agent_string.split(" "))
+
+    @pytest.mark.parametrize("found_classes", [["RandomOperator"], ["BaseSensorOperator", "TestSensor"]])
+    @mock.patch.object(AwsBaseHook, "_find_class_name")
+    def test_user_agent_caller_target_function_found(self, mock_class_name, found_classes):
+        mock_class_name.side_effect = found_classes
+
+        user_agent_tags = self.fetch_tags()
+
+        assert mock_class_name.call_count == len(found_classes)
+        assert user_agent_tags["Caller"] == found_classes[-1]
+
+    def test_user_agent_caller_target_function_not_found(self):
+        default_caller_name = "Unknown"
+
+        user_agent_tags = self.fetch_tags()
+
+        assert user_agent_tags["Caller"] == default_caller_name
+
+    @pytest.mark.parametrize("env_var, expected_version", [({"AIRFLOW_CTX_DAG_ID": "banana"}, 5), [{}, None]])
+    @mock.patch.object(AwsBaseHook, "_get_caller", return_value="Test")
+    def test_user_agent_dag_run_key_is_hashed_correctly(self, _, env_var, expected_version):
+        with mock.patch.dict(os.environ, env_var, clear=True):
+            dag_run_key = self.fetch_tags()["DagRunKey"]
+
+            assert UUID(dag_run_key).version == expected_version
+
     @mock.patch.object(AwsBaseHook, "get_connection")
     @mock_sts
     def test_assume_role(self, mock_get_connection):
@@ -346,7 +399,7 @@ class TestAwsBaseHook:
             hook.get_client_type("s3")
 
         calls_assume_role = [
-            mock.call.session.Session().client("sts", config=None),
+            mock.call.session.Session().client("sts", config=mock.ANY),
             mock.call.session.Session()
             .client()
             .assume_role(
@@ -510,7 +563,7 @@ class TestAwsBaseHook:
             mock_xpath.assert_called_once_with(xpath)
 
         calls_assume_role_with_saml = [
-            mock.call.session.Session().client("sts", config=None),
+            mock.call.session.Session().client("sts", config=mock.ANY),
             mock.call.session.Session()
             .client()
             .assume_role_with_saml(
@@ -735,7 +788,7 @@ class TestAwsBaseHook:
         mock_session_factory.assert_called_once_with(
             conn=hook.conn_config,
             region_name=method_region_name,
-            config=hook_botocore_config,
+            config=mock.ANY,
         )
         assert mock_session_factory_instance.create_session.assert_called_once
         assert session == MOCK_BOTO3_SESSION
@@ -807,7 +860,7 @@ class ThrowErrorUntilCount:
 
     def __call__(self):
         """
-        Raise an Forbidden until after count threshold has been crossed.
+        Raise an Exception until after count threshold has been crossed.
         Then return True.
         """
         if self.counter < self.count:
@@ -839,7 +892,8 @@ class TestRetryDecorator:  # ptlint: disable=invalid-name
         result = _retryable_test(lambda: 42)
         assert result, 42
 
-    def test_retry_on_exception(self):
+    @mock.patch("time.sleep", return_value=0)
+    def test_retry_on_exception(self, _):
         quota_retry = {
             "stop_after_delay": 2,
             "multiplier": 1,

Reply via email to