This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e7214fd5f4 Reduce `s3hook` memory usage (#37886)
e7214fd5f4 is described below

commit e7214fd5f4bf4a6a0f7eb365a2ebde2346a0bd20
Author: ellisms <[email protected]>
AuthorDate: Wed Mar 6 11:26:07 2024 -0500

    Reduce `s3hook` memory usage (#37886)
---
 airflow/providers/amazon/aws/hooks/s3.py    | 53 +++++++++++++----------------
 tests/providers/amazon/aws/hooks/test_s3.py |  4 +++
 2 files changed, 27 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 28b9adf7c6..580bb8ce79 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -30,16 +30,18 @@ import warnings
 from contextlib import suppress
 from copy import deepcopy
 from datetime import datetime
-from functools import wraps
+from functools import cached_property, wraps
 from inspect import signature
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
 if TYPE_CHECKING:
+    from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
+
     from airflow.utils.types import ArgNotSet
 
     with suppress(ImportError):
@@ -55,22 +57,17 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import chunks
 
-if TYPE_CHECKING:
-    from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
-
-T = TypeVar("T", bound=Callable)
-
 logger = logging.getLogger(__name__)
 
 
-def provide_bucket_name(func: T) -> T:
+def provide_bucket_name(func: Callable) -> Callable:
     """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
         logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
     function_signature = signature(func)
 
     @wraps(func)
-    def wrapper(*args, **kwargs) -> T:
+    def wrapper(*args, **kwargs) -> Callable:
         bound_args = function_signature.bind(*args, **kwargs)
 
         if "bucket_name" not in bound_args.arguments:
@@ -90,10 +87,10 @@ def provide_bucket_name(func: T) -> T:
 
         return func(*bound_args.args, **bound_args.kwargs)
 
-    return cast(T, wrapper)
+    return wrapper
 
 
-def provide_bucket_name_async(func: T) -> T:
+def provide_bucket_name_async(func: Callable) -> Callable:
     """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     function_signature = signature(func)
 
@@ -110,15 +107,15 @@ def provide_bucket_name_async(func: T) -> T:
 
         return await func(*bound_args.args, **bound_args.kwargs)
 
-    return cast(T, wrapper)
+    return wrapper
 
 
-def unify_bucket_name_and_key(func: T) -> T:
+def unify_bucket_name_and_key(func: Callable) -> Callable:
     """Unify bucket name and key in case no bucket name and at least a key has been passed to the function."""
     function_signature = signature(func)
 
     @wraps(func)
-    def wrapper(*args, **kwargs) -> T:
+    def wrapper(*args, **kwargs) -> Callable:
         bound_args = function_signature.bind(*args, **kwargs)
 
         if "wildcard_key" in bound_args.arguments:
@@ -141,7 +138,7 @@ def unify_bucket_name_and_key(func: T) -> T:
     # if provide_bucket_name is applied first, and there's a bucket defined in conn
     # then if user supplies full key, bucket in key is not respected
     wrapper._unify_bucket_name_and_key_wrapped = True  # type: ignore[attr-defined]
-    return cast(T, wrapper)
+    return wrapper
 
 
 class S3Hook(AwsBaseHook):
@@ -188,6 +185,15 @@ class S3Hook(AwsBaseHook):
 
         super().__init__(*args, **kwargs)
 
+    @cached_property
+    def resource(self):
+        return self.get_session().resource(
+            self.service_name,
+            endpoint_url=self.conn_config.get_service_endpoint_url(service_name=self.service_name),
+            config=self.config,
+            verify=self.verify,
+        )
+
     @property
     def extra_args(self):
         """Return hook's extra arguments (immutable)."""
@@ -307,13 +313,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_name: the name of the bucket
         :return: the bucket object to the bucket name.
         """
-        s3_resource = self.get_session().resource(
-            "s3",
-            endpoint_url=self.conn_config.endpoint_url,
-            config=self.config,
-            verify=self.verify,
-        )
-        return s3_resource.Bucket(bucket_name)
+        return self.resource.Bucket(bucket_name)
 
     @provide_bucket_name
     def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
@@ -943,14 +943,7 @@ class S3Hook(AwsBaseHook):
                 if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
             }
 
-        s3_resource = self.get_session().resource(
-            "s3",
-            endpoint_url=self.conn_config.endpoint_url,
-            config=self.config,
-            verify=self.verify,
-        )
-        obj = s3_resource.Object(bucket_name, key)
-
+        obj = self.resource.Object(bucket_name, key)
         obj.load(**sanitize_extra_args())
         return obj
 
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index bea5828e67..9ee5e05a9a 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -62,6 +62,10 @@ class TestAwsS3Hook:
         hook = S3Hook()
         assert hook.get_conn() is not None
 
+    def test_resource(self):
+        hook = S3Hook()
+        assert hook.resource is not None
+
     def test_use_threads_default_value(self):
         hook = S3Hook()
         assert hook.transfer_config.use_threads is True

Reply via email to