This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e7214fd5f4 Reduce `s3hook` memory usage (#37886)
e7214fd5f4 is described below
commit e7214fd5f4bf4a6a0f7eb365a2ebde2346a0bd20
Author: ellisms <[email protected]>
AuthorDate: Wed Mar 6 11:26:07 2024 -0500
Reduce `s3hook` memory usage (#37886)
---
airflow/providers/amazon/aws/hooks/s3.py | 53 +++++++++++++----------------
tests/providers/amazon/aws/hooks/test_s3.py | 4 +++
2 files changed, 27 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 28b9adf7c6..580bb8ce79 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -30,16 +30,18 @@ import warnings
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
-from functools import wraps
+from functools import cached_property, wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urlsplit
from uuid import uuid4
if TYPE_CHECKING:
+ from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
+
from airflow.utils.types import ArgNotSet
with suppress(ImportError):
@@ -55,22 +57,17 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.tags import format_tags
from airflow.utils.helpers import chunks
-if TYPE_CHECKING:
- from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
-
-T = TypeVar("T", bound=Callable)
-
logger = logging.getLogger(__name__)
-def provide_bucket_name(func: T) -> T:
+def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)
@wraps(func)
- def wrapper(*args, **kwargs) -> T:
+ def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)
if "bucket_name" not in bound_args.arguments:
@@ -90,10 +87,10 @@ def provide_bucket_name(func: T) -> T:
return func(*bound_args.args, **bound_args.kwargs)
- return cast(T, wrapper)
+ return wrapper
-def provide_bucket_name_async(func: T) -> T:
+def provide_bucket_name_async(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
function_signature = signature(func)
@@ -110,15 +107,15 @@ def provide_bucket_name_async(func: T) -> T:
return await func(*bound_args.args, **bound_args.kwargs)
- return cast(T, wrapper)
+ return wrapper
-def unify_bucket_name_and_key(func: T) -> T:
+def unify_bucket_name_and_key(func: Callable) -> Callable:
"""Unify bucket name and key in case no bucket name and at least a key has been passed to the function."""
function_signature = signature(func)
@wraps(func)
- def wrapper(*args, **kwargs) -> T:
+ def wrapper(*args, **kwargs) -> Callable:
bound_args = function_signature.bind(*args, **kwargs)
if "wildcard_key" in bound_args.arguments:
@@ -141,7 +138,7 @@ def unify_bucket_name_and_key(func: T) -> T:
# if provide_bucket_name is applied first, and there's a bucket defined in conn
# then if user supplies full key, bucket in key is not respected
wrapper._unify_bucket_name_and_key_wrapped = True  # type: ignore[attr-defined]
- return cast(T, wrapper)
+ return wrapper
class S3Hook(AwsBaseHook):
@@ -188,6 +185,15 @@ class S3Hook(AwsBaseHook):
super().__init__(*args, **kwargs)
+ @cached_property
+ def resource(self):
+ return self.get_session().resource(
+ self.service_name,
+ endpoint_url=self.conn_config.get_service_endpoint_url(service_name=self.service_name),
+ config=self.config,
+ verify=self.verify,
+ )
+
@property
def extra_args(self):
"""Return hook's extra arguments (immutable)."""
@@ -307,13 +313,7 @@ class S3Hook(AwsBaseHook):
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
"""
- s3_resource = self.get_session().resource(
- "s3",
- endpoint_url=self.conn_config.endpoint_url,
- config=self.config,
- verify=self.verify,
- )
- return s3_resource.Bucket(bucket_name)
+ return self.resource.Bucket(bucket_name)
@provide_bucket_name
def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
@@ -943,14 +943,7 @@ class S3Hook(AwsBaseHook):
if arg_name in S3Transfer.ALLOWED_DOWNLOAD_ARGS
}
- s3_resource = self.get_session().resource(
- "s3",
- endpoint_url=self.conn_config.endpoint_url,
- config=self.config,
- verify=self.verify,
- )
- obj = s3_resource.Object(bucket_name, key)
-
+ obj = self.resource.Object(bucket_name, key)
obj.load(**sanitize_extra_args())
return obj
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index bea5828e67..9ee5e05a9a 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -62,6 +62,10 @@ class TestAwsS3Hook:
hook = S3Hook()
assert hook.get_conn() is not None
+ def test_resource(self):
+ hook = S3Hook()
+ assert hook.resource is not None
+
def test_use_threads_default_value(self):
hook = S3Hook()
assert hook.transfer_config.use_threads is True