This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a6315c2f4e Amazon Provider Package user agent (#27823)
a6315c2f4e is described below
commit a6315c2f4ed68c822d0109f9609c1518e0bde94e
Author: D. Ferruzzi <[email protected]>
AuthorDate: Thu Dec 8 10:51:20 2022 -0800
Amazon Provider Package user agent (#27823)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 134 +++++++++++++++++++---
tests/providers/amazon/aws/hooks/test_base_aws.py | 64 ++++++++++-
2 files changed, 177 insertions(+), 21 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index ec06eb22e5..6ff5f21467 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -25,9 +25,13 @@ This module contains Base AWS Hook.
from __future__ import annotations
import datetime
+import inspect
import json
import logging
+import os
+import uuid
import warnings
+from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
@@ -47,6 +51,7 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import exactly_one
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
@@ -409,6 +414,92 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
self._config = config
self._verify = verify
+ @classmethod
+ def _get_provider_version(cls) -> str:
+ """Checks the Providers Manager for the package version."""
+ try:
+ manager = ProvidersManager()
+ hook = manager.hooks[cls.conn_type]
+ if not hook:
+ # This gets caught immediately, but without it MyPy complains
+ # Item "None" of "Optional[HookInfo]" has no attribute "package_name"
+ # on the following line and static checks fail.
+ raise ValueError(f"Hook info for {cls.conn_type} not found in the Provider Manager.")
+ provider = manager.providers[hook.package_name]
+ return provider.version
+ except Exception:
+ # Under no condition should an error here ever cause an issue for the user.
+ return "Unknown"
+
+ @staticmethod
+ def _find_class_name(target_function_name: str) -> str:
+ """
+ Given a frame off the stack, return the name of the class which made the call.
+ Note: This method may raise a ValueError or an IndexError, but the calling
+ method is catching and handling those.
+ """
+ stack = inspect.stack()
+ # Find the index of the most recent frame which called the provided function name.
+ target_frame_index = [frame.function for frame in stack].index(target_function_name)
+ # Pull that frame off the stack.
+ target_frame = stack[target_frame_index][0]
+ # Get the local variables for that frame.
+ frame_variables = target_frame.f_locals["self"]
+ # Get the class object for that frame.
+ frame_class_object = frame_variables.__class__
+ # Return the name of the class object.
+ return frame_class_object.__name__
+
+ def _get_caller(self, target_function_name: str = "execute") -> str:
+ """Given a function name, walk the stack and return the name of the class which called it last."""
+ try:
+ caller = self._find_class_name(target_function_name)
+ if caller == "BaseSensorOperator":
+ # If the result is a BaseSensorOperator, then look for whatever last called "poke".
+ return self._get_caller("poke")
+ return caller
+ except Exception:
+ # Under no condition should an error here ever cause an issue for the user.
+ return "Unknown"
+
+ @staticmethod
+ def _generate_dag_key() -> str:
+ """
+ The Object Identifier (OID) namespace is used to salt the dag_id value.
+ That salted value is used to generate a SHA-1 hash which, by definition,
+ can not (reasonably) be reversed. No personal data can be inferred or
+ extracted from the resulting UUID.
+ """
+ try:
+ dag_id = os.environ["AIRFLOW_CTX_DAG_ID"]
+ return str(uuid.uuid5(uuid.NAMESPACE_OID, dag_id))
+ except Exception:
+ # Under no condition should an error here ever cause an issue for the user.
+ return "00000000-0000-0000-0000-000000000000"
+
+ @staticmethod
+ def _get_airflow_version() -> str:
+ """Fetch and return the current Airflow version."""
+ try:
+ # This can be a circular import under specific configurations.
+ # Importing locally to either avoid or catch it if it does happen.
+ from airflow import __version__ as airflow_version
+
+ return airflow_version
+ except Exception:
+ # Under no condition should an error here ever cause an issue for the user.
+ return "Unknown"
+
+ def _generate_user_agent_extra_field(self, existing_user_agent_extra: str) -> str:
+ user_agent_extra_values = [
+ f"Airflow/{self._get_airflow_version()}",
+ f"AmPP/{self._get_provider_version()}",
+ f"Caller/{self._get_caller()}",
+ f"DagRunKey/{self._generate_dag_key()}",
+ existing_user_agent_extra or "",
+ ]
+ return " ".join(user_agent_extra_values).strip()
+
@cached_property
def conn_config(self) -> AwsConnectionWrapper:
"""Get the Airflow Connection object and wrap it in helper (cached)."""
@@ -436,9 +527,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
return self.conn_config.region_name
@property
- def config(self) -> Config | None:
+ def config(self) -> Config:
"""Configuration for botocore client read-only property."""
- return self.conn_config.botocore_config
+ return self.conn_config.botocore_config or botocore.config.Config()
@property
def verify(self) -> bool | str | None:
@@ -451,6 +542,23 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
conn=self.conn_config, region_name=region_name, config=self.config
).create_session()
+ def _get_config(self, config: Config | None = None) -> Config:
+ """
+ No AWS Operators use the config argument to this method.
+ Keep backward compatibility with other users who might use it
+ """
+ if config is None:
+ config = deepcopy(self.config)
+
+ # ignore[union-attr] is required for this block to appease MyPy
+ # because the user_agent_extra field is generated at runtime.
+ user_agent_config = Config(
+ user_agent_extra=self._generate_user_agent_extra_field(
+ existing_user_agent_extra=config.user_agent_extra # type: ignore[union-attr]
+ )
+ )
+ return config.merge(user_agent_config) # type: ignore[union-attr]
+
def get_client_type(
self,
region_name: str | None = None,
@@ -458,15 +566,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
client_type = self.client_type
-
- # No AWS Operators use the config argument to this method.
- # Keep backward compatibility with other users who might use it
- if config is None:
- config = self.config
-
session = self.get_session(region_name=region_name)
return session.client(
- client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ client_type,
+ endpoint_url=self.conn_config.endpoint_url,
+ config=self._get_config(config),
+ verify=self.verify,
)
def get_resource_type(
@@ -476,15 +581,12 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
resource_type = self.resource_type
-
- # No AWS Operators use the config argument to this method.
- # Keep backward compatibility with other users who might use it
- if config is None:
- config = self.config
-
session = self.get_session(region_name=region_name)
return session.resource(
- resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ resource_type,
+ endpoint_url=self.conn_config.endpoint_url,
+ config=self._get_config(config),
+ verify=self.verify,
)
@cached_property
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 3a17b39e10..8db11fbb38 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -22,6 +22,7 @@ import os
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from unittest import mock
+from uuid import UUID
import boto3
import pytest
@@ -301,6 +302,58 @@ class TestAwsBaseHook:
assert table.item_count == 0
+ @pytest.mark.parametrize(
+ "client_meta",
+ [
+ AwsBaseHook(client_type="s3").get_client_type().meta,
+ AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
+ ],
+ )
+ def test_user_agent_extra_update(self, client_meta):
+ """
+ We are only looking for the keys appended by the AwsBaseHook. A user_agent string
+ is a number of key/value pairs such as: `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`.
+ """
+ expected_user_agent_tag_keys = ["Airflow", "AmPP", "Caller", "DagRunKey"]
+
+ result_user_agent_tags = client_meta.config.user_agent.split(" ")
+ result_user_agent_tag_keys = [tag.split("/")[0].lower() for tag in result_user_agent_tags]
+
+ for key in expected_user_agent_tag_keys:
+ assert key.lower() in result_user_agent_tag_keys
+
+ @staticmethod
+ def fetch_tags() -> dict[str:str]:
+ """Helper method which creates an AwsBaseHook and returns the user agent string split into a dict."""
+ user_agent_string = AwsBaseHook(client_type="s3").get_client_type().meta.config.user_agent
+ # Split the list of {Key}/{Value} into a dict
+ return dict(tag.split("/") for tag in user_agent_string.split(" "))
+
+ @pytest.mark.parametrize("found_classes", [["RandomOperator"], ["BaseSensorOperator", "TestSensor"]])
+ @mock.patch.object(AwsBaseHook, "_find_class_name")
+ def test_user_agent_caller_target_function_found(self, mock_class_name, found_classes):
+ mock_class_name.side_effect = found_classes
+
+ user_agent_tags = self.fetch_tags()
+
+ assert mock_class_name.call_count == len(found_classes)
+ assert user_agent_tags["Caller"] == found_classes[-1]
+
+ def test_user_agent_caller_target_function_not_found(self):
+ default_caller_name = "Unknown"
+
+ user_agent_tags = self.fetch_tags()
+
+ assert user_agent_tags["Caller"] == default_caller_name
+
+ @pytest.mark.parametrize("env_var, expected_version", [({"AIRFLOW_CTX_DAG_ID": "banana"}, 5), [{}, None]])
+ @mock.patch.object(AwsBaseHook, "_get_caller", return_value="Test")
+ def test_user_agent_dag_run_key_is_hashed_correctly(self, _, env_var, expected_version):
+ with mock.patch.dict(os.environ, env_var, clear=True):
+ dag_run_key = self.fetch_tags()["DagRunKey"]
+
+ assert UUID(dag_run_key).version == expected_version
+
@mock.patch.object(AwsBaseHook, "get_connection")
@mock_sts
def test_assume_role(self, mock_get_connection):
@@ -346,7 +399,7 @@ class TestAwsBaseHook:
hook.get_client_type("s3")
calls_assume_role = [
- mock.call.session.Session().client("sts", config=None),
+ mock.call.session.Session().client("sts", config=mock.ANY),
mock.call.session.Session()
.client()
.assume_role(
@@ -510,7 +563,7 @@ class TestAwsBaseHook:
mock_xpath.assert_called_once_with(xpath)
calls_assume_role_with_saml = [
- mock.call.session.Session().client("sts", config=None),
+ mock.call.session.Session().client("sts", config=mock.ANY),
mock.call.session.Session()
.client()
.assume_role_with_saml(
@@ -735,7 +788,7 @@ class TestAwsBaseHook:
mock_session_factory.assert_called_once_with(
conn=hook.conn_config,
region_name=method_region_name,
- config=hook_botocore_config,
+ config=mock.ANY,
)
assert mock_session_factory_instance.create_session.assert_called_once
assert session == MOCK_BOTO3_SESSION
@@ -807,7 +860,7 @@ class ThrowErrorUntilCount:
def __call__(self):
"""
- Raise an Forbidden until after count threshold has been crossed.
+ Raise an Exception until after count threshold has been crossed.
Then return True.
"""
if self.counter < self.count:
@@ -839,7 +892,8 @@ class TestRetryDecorator: # ptlint: disable=invalid-name
result = _retryable_test(lambda: 42)
assert result, 42
- def test_retry_on_exception(self):
+ @mock.patch("time.sleep", return_value=0)
+ def test_retry_on_exception(self, _):
quota_retry = {
"stop_after_delay": 2,
"multiplier": 1,