This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c50541142b9 Base AWS classes - S3 (#47321)
c50541142b9 is described below
commit c50541142b91269c2afc55bd666293500da1102b
Author: Niko Oliveira <[email protected]>
AuthorDate: Wed Mar 5 14:46:02 2025 -0800
Base AWS classes - S3 (#47321)
---
.../airflow/providers/amazon/aws/operators/s3.py | 304 ++++++++++-----------
.../src/airflow/providers/amazon/aws/sensors/s3.py | 73 +++--
.../airflow/providers/amazon/aws/triggers/s3.py | 33 ++-
.../tests/unit/amazon/aws/operators/test_s3.py | 26 +-
.../tests/unit/amazon/aws/sensors/test_s3.py | 6 +-
.../tests/unit/amazon/aws/triggers/test_s3.py | 5 +
.../unit/google/cloud/transfers/test_s3_to_gcs.py | 46 ++--
7 files changed, 248 insertions(+), 245 deletions(-)
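The pattern applied throughout the hunks below is the same for every operator and sensor touched here: subclass the generic AWS base class, declare the hook via aws_hook_class, build template_fields with aws_template_fields(), and drop the per-operator aws_conn_id/region_name/verify plumbing in favour of the self.hook the base class provides. A simplified sketch of that shape (illustrative only, not the exact committed code):

    from collections.abc import Sequence

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook
    from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
    from airflow.providers.amazon.aws.utils.mixins import aws_template_fields


    class ExampleS3Operator(AwsBaseOperator[S3Hook]):
        """Sketch only: aws_conn_id, region_name, verify and botocore_config are handled by the base class."""

        aws_hook_class = S3Hook  # the base class builds self.hook from this
        template_fields: Sequence[str] = aws_template_fields("bucket_name")

        def __init__(self, *, bucket_name: str, **kwargs) -> None:
            super().__init__(**kwargs)
            self.bucket_name = bucket_name

        def execute(self, context):
            # self.hook is an S3Hook configured from the operator's AWS arguments
            if not self.hook.check_for_bucket(self.bucket_name):
                self.hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name)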
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py
index 05c5bd88634..406d9f59752 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py
@@ -29,8 +29,9 @@ import pytz
from dateutil import parser
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.utils.helpers import exactly_one
if TYPE_CHECKING:
@@ -41,7 +42,7 @@ if TYPE_CHECKING:
BUCKET_DOES_NOT_EXIST_MSG = "Bucket with name: %s doesn't exist"
-class S3CreateBucketOperator(BaseOperator):
+class S3CreateBucketOperator(AwsBaseOperator[S3Hook]):
"""
This operator creates an S3 bucket.
@@ -51,38 +52,38 @@ class S3CreateBucketOperator(BaseOperator):
:param bucket_name: This is bucket name you want to create
:param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
- :param region_name: AWS region_name. If not specified fetched from connection.
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("bucket_name",)
+ template_fields: Sequence[str] = aws_template_fields("bucket_name")
+ aws_hook_class = S3Hook
def __init__(
self,
*,
bucket_name: str,
- aws_conn_id: str | None = "aws_default",
- region_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
- self.region_name = region_name
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
- if not s3_hook.check_for_bucket(self.bucket_name):
- s3_hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name)
+ if not self.hook.check_for_bucket(self.bucket_name):
+ self.hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name)
self.log.info("Created bucket with name: %s", self.bucket_name)
else:
self.log.info("Bucket with name: %s already exists", self.bucket_name)
-class S3DeleteBucketOperator(BaseOperator):
+class S3DeleteBucketOperator(AwsBaseOperator[S3Hook]):
"""
This operator deletes an S3 bucket.
@@ -93,36 +94,39 @@ class S3DeleteBucketOperator(BaseOperator):
:param bucket_name: This is bucket name you want to delete
:param force_delete: Forcibly delete all objects in the bucket before
deleting the bucket
:param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("bucket_name",)
+ template_fields: Sequence[str] = aws_template_fields("bucket_name")
+ aws_hook_class = S3Hook
def __init__(
self,
bucket_name: str,
force_delete: bool = False,
- aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.force_delete = force_delete
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
- if s3_hook.check_for_bucket(self.bucket_name):
- s3_hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete)
+ if self.hook.check_for_bucket(self.bucket_name):
+ self.hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete)
self.log.info("Deleted bucket with name: %s", self.bucket_name)
else:
self.log.info("Bucket with name: %s doesn't exist", self.bucket_name)
-class S3GetBucketTaggingOperator(BaseOperator):
+class S3GetBucketTaggingOperator(AwsBaseOperator[S3Hook]):
"""
This operator gets tagging from an S3 bucket.
@@ -132,31 +136,34 @@ class S3GetBucketTaggingOperator(BaseOperator):
:param bucket_name: This is bucket name you want to reference
:param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("bucket_name",)
+ template_fields: Sequence[str] = aws_template_fields("bucket_name")
+ aws_hook_class = S3Hook
def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
-
- if s3_hook.check_for_bucket(self.bucket_name):
+ if self.hook.check_for_bucket(self.bucket_name):
self.log.info("Getting tags for bucket %s", self.bucket_name)
- return s3_hook.get_bucket_tagging(self.bucket_name)
+ return self.hook.get_bucket_tagging(self.bucket_name)
else:
self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
return None
-class S3PutBucketTaggingOperator(BaseOperator):
+class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]):
"""
This operator puts tagging for an S3 bucket.
@@ -171,14 +178,20 @@ class S3PutBucketTaggingOperator(BaseOperator):
If a value is provided, a key must be provided as well.
:param tag_set: A dictionary containing the tags, or a List of key/value
pairs.
:param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
- empty, then the default boto3 configuration would be used (and must be
+ empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("bucket_name",)
+ template_fields: Sequence[str] = aws_template_fields("bucket_name")
template_fields_renderers = {"tag_set": "json"}
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -186,7 +199,6 @@ class S3PutBucketTaggingOperator(BaseOperator):
key: str | None = None,
value: str | None = None,
tag_set: dict | list[dict[str, str]] | None = None,
- aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -194,14 +206,11 @@ class S3PutBucketTaggingOperator(BaseOperator):
self.value = value
self.tag_set = tag_set
self.bucket_name = bucket_name
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
-
- if s3_hook.check_for_bucket(self.bucket_name):
+ if self.hook.check_for_bucket(self.bucket_name):
self.log.info("Putting tags for bucket %s", self.bucket_name)
- return s3_hook.put_bucket_tagging(
+ return self.hook.put_bucket_tagging(
key=self.key, value=self.value, tag_set=self.tag_set, bucket_name=self.bucket_name
)
else:
@@ -209,7 +218,7 @@ class S3PutBucketTaggingOperator(BaseOperator):
return None
-class S3DeleteBucketTaggingOperator(BaseOperator):
+class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]):
"""
This operator deletes tagging from an S3 bucket.
@@ -219,31 +228,38 @@ class S3DeleteBucketTaggingOperator(BaseOperator):
:param bucket_name: This is the name of the bucket to delete tags from.
:param aws_conn_id: The Airflow connection used for AWS credentials.
- If this is None or empty then the default boto3 behaviour is used. If
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("bucket_name",)
+ template_fields: Sequence[str] = aws_template_fields("bucket_name")
+ aws_hook_class = S3Hook
- def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
+ def __init__(
+ self,
+ bucket_name: str,
+ **kwargs,
+ ) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
- self.aws_conn_id = aws_conn_id
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
-
- if s3_hook.check_for_bucket(self.bucket_name):
+ if self.hook.check_for_bucket(self.bucket_name):
self.log.info("Deleting tags for bucket %s", self.bucket_name)
- return s3_hook.delete_bucket_tagging(self.bucket_name)
+ return self.hook.delete_bucket_tagging(self.bucket_name)
else:
self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name)
return None
-class S3CopyObjectOperator(BaseOperator):
+class S3CopyObjectOperator(AwsBaseOperator[S3Hook]):
"""
Creates a copy of an object that is already stored in S3.
@@ -269,30 +285,29 @@ class S3CopyObjectOperator(BaseOperator):
It should be omitted when `dest_bucket_key` is provided as a full
s3:// url.
:param source_version_id: Version ID of the source object (OPTIONAL)
- :param aws_conn_id: Connection id of the S3 connection to use
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
-
- You can provide the following values:
-
- - False: do not validate SSL certificates. SSL will still be used,
- but SSL certificates will not be
- verified.
- - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:param acl_policy: String specifying the canned ACL policy for the file being
uploaded to the S3 bucket.
:param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with
metadata that's provided in the request.
"""
- template_fields: Sequence[str] = (
+ template_fields: Sequence[str] = aws_template_fields(
"source_bucket_key",
"dest_bucket_key",
"source_bucket_name",
"dest_bucket_name",
)
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -302,8 +317,6 @@ class S3CopyObjectOperator(BaseOperator):
source_bucket_name: str | None = None,
dest_bucket_name: str | None = None,
source_version_id: str | None = None,
- aws_conn_id: str | None = "aws_default",
- verify: str | bool | None = None,
acl_policy: str | None = None,
meta_data_directive: str | None = None,
**kwargs,
@@ -315,14 +328,11 @@ class S3CopyObjectOperator(BaseOperator):
self.source_bucket_name = source_bucket_name
self.dest_bucket_name = dest_bucket_name
self.source_version_id = source_version_id
- self.aws_conn_id = aws_conn_id
- self.verify = verify
self.acl_policy = acl_policy
self.meta_data_directive = meta_data_directive
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
- s3_hook.copy_object(
+ self.hook.copy_object(
self.source_bucket_key,
self.dest_bucket_key,
self.source_bucket_name,
@@ -336,11 +346,11 @@ class S3CopyObjectOperator(BaseOperator):
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.openlineage.extractors import OperatorLineage
- dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
+ dest_bucket_name, dest_bucket_key = self.hook.get_s3_bucket_key(
self.dest_bucket_name, self.dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
)
- source_bucket_name, source_bucket_key = S3Hook.get_s3_bucket_key(
+ source_bucket_name, source_bucket_key = self.hook.get_s3_bucket_key(
self.source_bucket_name, self.source_bucket_key, "source_bucket_name", "source_bucket_key"
)
@@ -359,7 +369,7 @@ class S3CopyObjectOperator(BaseOperator):
)
-class S3CreateObjectOperator(BaseOperator):
+class S3CreateObjectOperator(AwsBaseOperator[S3Hook]):
"""
Creates a new object from `data` as string or bytes.
@@ -382,22 +392,21 @@ class S3CreateObjectOperator(BaseOperator):
It should be specified only when `data` is provided as string.
:param compression: Type of compression to use, currently only gzip is supported.
It can be specified only when `data` is provided as string.
- :param aws_conn_id: Connection id of the S3 connection to use
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
-
- You can provide the following values:
-
- - False: do not validate SSL certificates. SSL will still be used,
- but SSL certificates will not be
- verified.
- - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("s3_bucket", "s3_key", "data")
+ template_fields: Sequence[str] = aws_template_fields("s3_bucket", "s3_key", "data")
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -410,8 +419,6 @@ class S3CreateObjectOperator(BaseOperator):
acl_policy: str | None = None,
encoding: str | None = None,
compression: str | None = None,
- aws_conn_id: str | None = "aws_default",
- verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -424,16 +431,14 @@ class S3CreateObjectOperator(BaseOperator):
self.acl_policy = acl_policy
self.encoding = encoding
self.compression = compression
- self.aws_conn_id = aws_conn_id
- self.verify = verify
def execute(self, context: Context):
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
- s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
+ s3_bucket, s3_key = self.hook.get_s3_bucket_key(
+ self.s3_bucket, self.s3_key, "dest_bucket", "dest_key"
+ )
if isinstance(self.data, str):
- s3_hook.load_string(
+ self.hook.load_string(
self.data,
s3_key,
s3_bucket,
@@ -444,13 +449,13 @@ class S3CreateObjectOperator(BaseOperator):
self.compression,
)
else:
- s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)
+ self.hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)
def get_openlineage_facets_on_start(self):
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.openlineage.extractors import OperatorLineage
- bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
+ bucket, key = self.hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
output_dataset = Dataset(
namespace=f"s3://{bucket}",
@@ -462,7 +467,7 @@ class S3CreateObjectOperator(BaseOperator):
)
-class S3DeleteObjectsOperator(BaseOperator):
+class S3DeleteObjectsOperator(AwsBaseOperator[S3Hook]):
"""
To enable users to delete single object or multiple objects from a bucket
using a single HTTP request.
@@ -485,21 +490,22 @@ class S3DeleteObjectsOperator(BaseOperator):
All objects which LastModified Date is greater than this datetime in the bucket will be deleted.
:param to_datetime: less LastModified Date of objects to delete. (templated)
All objects which LastModified Date is less than this datetime in the bucket will be deleted.
- :param aws_conn_id: Connection id of the S3 connection to use
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
-
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used,
- but SSL certificates will not be
- verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("keys", "bucket", "prefix", "from_datetime", "to_datetime")
+ template_fields: Sequence[str] = aws_template_fields(
+ "keys", "bucket", "prefix", "from_datetime", "to_datetime"
+ )
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -509,8 +515,6 @@ class S3DeleteObjectsOperator(BaseOperator):
prefix: str | None = None,
from_datetime: datetime | str | None = None,
to_datetime: datetime | str | None = None,
- aws_conn_id: str | None = "aws_default",
- verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -519,8 +523,6 @@ class S3DeleteObjectsOperator(BaseOperator):
self.prefix = prefix
self.from_datetime = from_datetime
self.to_datetime = to_datetime
- self.aws_conn_id = aws_conn_id
- self.verify = verify
self._keys: str | list[str] = ""
@@ -546,16 +548,14 @@ class S3DeleteObjectsOperator(BaseOperator):
if isinstance(self.from_datetime, str):
self.from_datetime = parser.parse(self.from_datetime).replace(tzinfo=pytz.UTC)
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
- keys = self.keys or s3_hook.list_keys(
+ keys = self.keys or self.hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
from_datetime=self.from_datetime,
to_datetime=self.to_datetime,
)
if keys:
- s3_hook.delete_objects(bucket=self.bucket, keys=keys)
+ self.hook.delete_objects(bucket=self.bucket, keys=keys)
self._keys = keys
def get_openlineage_facets_on_complete(self, task_instance):
@@ -598,7 +598,7 @@ class S3DeleteObjectsOperator(BaseOperator):
)
-class S3FileTransformOperator(BaseOperator):
+class S3FileTransformOperator(AwsBaseOperator[S3Hook]):
"""
Copies data from a source S3 location to a temporary location on the local
filesystem.
@@ -644,9 +644,10 @@ class S3FileTransformOperator(BaseOperator):
:param replace: Replace dest S3 key if it already exists
"""
- template_fields: Sequence[str] = ("source_s3_key", "dest_s3_key", "script_args")
+ template_fields: Sequence[str] = aws_template_fields("source_s3_key", "dest_s3_key", "script_args")
template_ext: Sequence[str] = ()
ui_color = "#f9c915"
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -682,6 +683,7 @@ class S3FileTransformOperator(BaseOperator):
if self.transform_script is None and self.select_expression is None:
raise AirflowException("Either transform_script or select_expression must be specified")
+ # Keep these hooks constructed here since we are using two unique conn_ids
source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify)
dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
@@ -770,7 +772,7 @@ class S3FileTransformOperator(BaseOperator):
)
-class S3ListOperator(BaseOperator):
+class S3ListOperator(AwsBaseOperator[S3Hook]):
"""
List all objects from the bucket with the given string prefix in name.
@@ -785,17 +787,16 @@ class S3ListOperator(BaseOperator):
:param prefix: Prefix string to filters the objects whose name begin with
such prefix. (templated)
:param delimiter: the delimiter marks key hierarchy. (templated)
- :param aws_conn_id: The connection ID to use when connecting to S3 storage.
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used
- (unless use_ssl is False), but SSL certificates will not be
- verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:param apply_wildcard: whether to treat '*' as a wildcard or a plain
symbol in the prefix.
@@ -813,8 +814,9 @@ class S3ListOperator(BaseOperator):
)
"""
- template_fields: Sequence[str] = ("bucket", "prefix", "delimiter")
+ template_fields: Sequence[str] = aws_template_fields("bucket", "prefix", "delimiter")
ui_color = "#ffd700"
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -822,8 +824,6 @@ class S3ListOperator(BaseOperator):
bucket: str,
prefix: str = "",
delimiter: str = "",
- aws_conn_id: str | None = "aws_default",
- verify: str | bool | None = None,
apply_wildcard: bool = False,
**kwargs,
):
@@ -831,13 +831,9 @@ class S3ListOperator(BaseOperator):
self.bucket = bucket
self.prefix = prefix
self.delimiter = delimiter
- self.aws_conn_id = aws_conn_id
- self.verify = verify
self.apply_wildcard = apply_wildcard
def execute(self, context: Context):
- hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
self.log.info(
"Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)",
self.bucket,
@@ -845,7 +841,7 @@ class S3ListOperator(BaseOperator):
self.delimiter,
)
- return hook.list_keys(
+ return self.hook.list_keys(
bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter,
@@ -853,7 +849,7 @@ class S3ListOperator(BaseOperator):
)
-class S3ListPrefixesOperator(BaseOperator):
+class S3ListPrefixesOperator(AwsBaseOperator[S3Hook]):
"""
List all subfolders from the bucket with the given string prefix in name.
@@ -868,17 +864,16 @@ class S3ListPrefixesOperator(BaseOperator):
:param prefix: Prefix string to filter the subfolders whose name begin with
such prefix. (templated)
:param delimiter: the delimiter marks subfolder hierarchy. (templated)
- :param aws_conn_id: The connection ID to use when connecting to S3 storage.
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used
- (unless use_ssl is False), but SSL certificates will not be
- verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
**Example**:
@@ -894,8 +889,9 @@ class S3ListPrefixesOperator(BaseOperator):
)
"""
- template_fields: Sequence[str] = ("bucket", "prefix", "delimiter")
+ template_fields: Sequence[str] = aws_template_fields("bucket", "prefix", "delimiter")
ui_color = "#ffd700"
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -903,20 +899,14 @@ class S3ListPrefixesOperator(BaseOperator):
bucket: str,
prefix: str,
delimiter: str,
- aws_conn_id: str | None = "aws_default",
- verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.bucket = bucket
self.prefix = prefix
self.delimiter = delimiter
- self.aws_conn_id = aws_conn_id
- self.verify = verify
def execute(self, context: Context):
- hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
self.log.info(
"Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)",
self.bucket,
@@ -924,4 +914,4 @@ class S3ListPrefixesOperator(BaseOperator):
self.delimiter,
)
- return hook.list_prefixes(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
+ return self.hook.list_prefixes(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
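At the DAG level nothing changes for callers: the connection-related arguments documented in the new docstrings above (aws_conn_id, region_name, verify, botocore_config) are still accepted, they are just consumed by AwsBaseOperator instead of each __init__. A usage sketch with hypothetical values:

    from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator

    # Hypothetical task; connection id, region and bucket name are placeholders.
    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        bucket_name="example-bucket",
        aws_conn_id="aws_default",
        region_name="us-east-1",
        verify=True,
        botocore_config={"retries": {"mode": "standard"}},
    )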
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
index bb3616597d2..c59bf53ea1c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
@@ -23,7 +23,6 @@ import os
import re
from collections.abc import Sequence
from datetime import datetime, timedelta
-from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, cast
from airflow.configuration import conf
@@ -34,11 +33,13 @@ if TYPE_CHECKING:
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger
-from airflow.sensors.base import BaseSensorOperator, poke_mode_only
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.sensors.base import poke_mode_only
-class S3KeySensor(BaseSensorOperator):
+class S3KeySensor(AwsBaseSensor[S3Hook]):
"""
Waits for one or multiple keys (a file-like instance on S3) to be present
in a S3 bucket.
@@ -65,17 +66,6 @@ class S3KeySensor(BaseSensorOperator):
def check_fn(files: List, **kwargs) -> bool:
return any(f.get('Size', 0) > 1048576 for f in files)
- :param aws_conn_id: a reference to the s3 connection
- :param verify: Whether to verify SSL certificates for S3 connection.
- By default, SSL certificates are verified.
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used
- (unless use_ssl is False), but SSL certificates will not be
- verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
:param deferrable: Run operator in the deferrable mode
:param use_regex: whether to use regex to check bucket
:param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
@@ -83,9 +73,18 @@ class S3KeySensor(BaseSensorOperator):
all available attributes.
Default value: "Size".
If the requested attribute is not found, the key is still included and
the value is None.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
"""
- template_fields: Sequence[str] = ("bucket_key", "bucket_name")
+ template_fields: Sequence[str] = aws_template_fields("bucket_key", "bucket_name")
+ aws_hook_class = S3Hook
def __init__(
self,
@@ -94,7 +93,6 @@ class S3KeySensor(BaseSensorOperator):
bucket_name: str | None = None,
wildcard_match: bool = False,
check_fn: Callable[..., bool] | None = None,
- aws_conn_id: str | None = "aws_default",
verify: str | bool | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
use_regex: bool = False,
@@ -106,14 +104,13 @@ class S3KeySensor(BaseSensorOperator):
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.check_fn = check_fn
- self.aws_conn_id = aws_conn_id
self.verify = verify
self.deferrable = deferrable
self.use_regex = use_regex
self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
def _check_key(self, key, context: Context):
- bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
+ bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
"""
@@ -199,7 +196,9 @@ class S3KeySensor(BaseSensorOperator):
bucket_key=self.bucket_key,
wildcard_match=self.wildcard_match,
aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
verify=self.verify,
+ botocore_config=self.botocore_config,
poke_interval=self.poke_interval,
should_check_fn=bool(self.check_fn),
use_regex=self.use_regex,
@@ -220,13 +219,9 @@ class S3KeySensor(BaseSensorOperator):
elif event["status"] == "error":
raise AirflowException(event["message"])
- @cached_property
- def hook(self) -> S3Hook:
- return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
@poke_mode_only
-class S3KeysUnchangedSensor(BaseSensorOperator):
+class S3KeysUnchangedSensor(AwsBaseSensor[S3Hook]):
"""
Return True if inactivity_period has passed with no increase in the number
of objects matching prefix.
@@ -239,17 +234,7 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
:param bucket_name: Name of the S3 bucket
:param prefix: The prefix being waited on. Relative path from bucket root
level.
- :param aws_conn_id: a reference to the s3 connection
- :param verify: Whether or not to verify SSL certificates for S3 connection.
- By default SSL certificates are verified.
- You can provide the following values:
-
- - ``False``: do not validate SSL certificates. SSL will still be used
- (unless use_ssl is False), but SSL certificates will not be
- verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
- You can specify this argument if you want to use a different
- CA cert bundle than the one used by botocore.
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param inactivity_period: The total seconds of inactivity to designate
keys unchanged. Note, this mechanism is not real time and
this operator may not return until a poke_interval after this period
@@ -261,16 +246,24 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
between pokes valid behavior. If true a warning message will be logged
when this happens. If false an error will be raised.
:param deferrable: Run sensor in the deferrable mode
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
"""
- template_fields: Sequence[str] = ("bucket_name", "prefix")
+ template_fields: Sequence[str] = aws_template_fields("bucket_name", "prefix")
+ aws_hook_class = S3Hook
def __init__(
self,
*,
bucket_name: str,
prefix: str,
- aws_conn_id: str | None = "aws_default",
verify: bool | str | None = None,
inactivity_period: float = 60 * 60,
min_objects: int = 1,
@@ -291,15 +284,9 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
self.inactivity_seconds = 0
self.allow_delete = allow_delete
self.deferrable = deferrable
- self.aws_conn_id = aws_conn_id
self.verify = verify
self.last_activity_time: datetime | None = None
- @cached_property
- def hook(self):
- """Returns S3Hook."""
- return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-
def is_keys_unchanged(self, current_objects: set[str]) -> bool:
"""
Check for new objects after the inactivity_period and update the sensor state accordingly.
@@ -382,7 +369,9 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
inactivity_seconds=self.inactivity_seconds,
allow_delete=self.allow_delete,
aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
verify=self.verify,
+ botocore_config=self.botocore_config,
last_activity_time=self.last_activity_time,
),
method_name="execute_complete",
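For the deferrable path, the sensor now forwards its region_name, verify and botocore_config into the trigger (see the defer() hunks above), so a deferrable sensor definition needs no extra wiring. A minimal sketch with hypothetical bucket and key names:

    from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

    # Hypothetical values; with deferrable=True the sensor defers to S3KeyTrigger,
    # passing along aws_conn_id, region_name, verify and botocore_config.
    wait_for_key = S3KeySensor(
        task_id="wait_for_key",
        bucket_name="example-bucket",
        bucket_key="incoming/data.csv",
        aws_conn_id="aws_default",
        deferrable=True,
    )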
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
index 0be6c992cc8..9d2b055fe44 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
@@ -53,6 +53,9 @@ class S3KeyTrigger(BaseTrigger):
poke_interval: float = 5.0,
should_check_fn: bool = False,
use_regex: bool = False,
+ region_name: str | None = None,
+ verify: bool | str | None = None,
+ botocore_config: dict | None = None,
**hook_params: Any,
):
super().__init__()
@@ -64,6 +67,9 @@ class S3KeyTrigger(BaseTrigger):
self.poke_interval = poke_interval
self.should_check_fn = should_check_fn
self.use_regex = use_regex
+ self.region_name = region_name
+ self.verify = verify
+ self.botocore_config = botocore_config
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize S3KeyTrigger arguments and classpath."""
@@ -78,12 +84,20 @@ class S3KeyTrigger(BaseTrigger):
"poke_interval": self.poke_interval,
"should_check_fn": self.should_check_fn,
"use_regex": self.use_regex,
+ "region_name": self.region_name,
+ "verify": self.verify,
+ "botocore_config": self.botocore_config,
},
)
@cached_property
def hook(self) -> S3Hook:
- return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
+ return S3Hook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make an asynchronous connection using S3HookAsync."""
@@ -143,7 +157,9 @@ class S3KeysUnchangedTrigger(BaseTrigger):
allow_delete: bool = True,
aws_conn_id: str | None = "aws_default",
last_activity_time: datetime | None = None,
+ region_name: str | None = None,
verify: bool | str | None = None,
+ botocore_config: dict | None = None,
**hook_params: Any,
):
super().__init__()
@@ -160,8 +176,10 @@ class S3KeysUnchangedTrigger(BaseTrigger):
self.allow_delete = allow_delete
self.aws_conn_id = aws_conn_id
self.last_activity_time = last_activity_time
- self.verify = verify
self.polling_period_seconds = 0
+ self.region_name = region_name
+ self.verify = verify
+ self.botocore_config = botocore_config
self.hook_params = hook_params
def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -179,14 +197,21 @@ class S3KeysUnchangedTrigger(BaseTrigger):
"aws_conn_id": self.aws_conn_id,
"last_activity_time": self.last_activity_time,
"hook_params": self.hook_params,
- "verify": self.verify,
"polling_period_seconds": self.polling_period_seconds,
+ "region_name": self.region_name,
+ "verify": self.verify,
+ "botocore_config": self.botocore_config,
},
)
@cached_property
def hook(self) -> S3Hook:
- return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
+ return S3Hook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make an asynchronous connection using S3Hook."""
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
index 0195afd6e2a..b17d4374c44 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
@@ -415,20 +415,19 @@ class TestS3FileTransformOperator:
class TestS3ListOperator:
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
-
+ def test_execute(self):
operator = S3ListOperator(
task_id="test-s3-list-operator",
bucket=BUCKET_NAME,
prefix="TEST",
delimiter=".csv",
)
+ operator.hook = mock.MagicMock()
+ operator.hook.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
files = operator.execute(None)
- mock_hook.return_value.list_keys.assert_called_once_with(
+ operator.hook.list_keys.assert_called_once_with(
bucket_name=BUCKET_NAME,
prefix="TEST",
delimiter=".csv",
@@ -447,17 +446,16 @@ class TestS3ListOperator:
class TestS3ListPrefixesOperator:
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.list_prefixes.return_value = ["test/"]
-
+ def test_execute(self):
operator = S3ListPrefixesOperator(
task_id="test-s3-list-prefixes-operator", bucket=BUCKET_NAME, prefix="test/", delimiter="/"
)
+ operator.hook = mock.MagicMock()
+ operator.hook.list_prefixes.return_value = ["test/"]
subfolders = operator.execute(None)
- mock_hook.return_value.list_prefixes.assert_called_once_with(
+ operator.hook.list_prefixes.assert_called_once_with(
bucket_name=BUCKET_NAME, prefix="test/", delimiter="/"
)
assert subfolders == ["test/"]
@@ -870,8 +868,7 @@ class TestS3DeleteObjectsOperator:
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test
@pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"]))
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
- def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys):
+ def test_get_openlineage_facets_on_complete_single_object(self, keys):
bucket = "testbucket"
expected_input = Dataset(
namespace=f"s3://{bucket}",
@@ -888,14 +885,14 @@ class TestS3DeleteObjectsOperator:
)
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
+ op.hook = mock.MagicMock()
op.execute(None)
lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 1
assert lineage.inputs[0] == expected_input
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
- def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
+ def test_get_openlineage_facets_on_complete_multiple_objects(self):
bucket = "testbucket"
keys = ["path/data1.txt", "path/data2.txt"]
expected_inputs = [
@@ -928,6 +925,7 @@ class TestS3DeleteObjectsOperator:
]
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
+ op.hook = mock.MagicMock()
op.execute(None)
lineage = op.get_openlineage_facets_on_complete(None)
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
index 9c169f86cf0..b9f8fb284bc 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
@@ -538,10 +538,10 @@ class TestS3KeysUnchangedSensor:
assert self.sensor.inactivity_seconds == period
time_machine.coordinates.shift(10)
- @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
- def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine):
+ def test_poke_succeeds_on_upload_complete(self, time_machine):
time_machine.move_to(DEFAULT_DATE)
- mock_hook.return_value.list_keys.return_value = {"a"}
+ self.sensor.hook = mock.MagicMock()
+ self.sensor.hook.list_keys.return_value = {"a"}
assert not self.sensor.poke(dict())
time_machine.coordinates.shift(10)
assert not self.sensor.poke(dict())
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
index 01533d29887..14c79f1e462 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
@@ -46,6 +46,9 @@ class TestS3KeyTrigger:
"poke_interval": 5.0,
"should_check_fn": False,
"use_regex": False,
+ "verify": None,
+ "region_name": None,
+ "botocore_config": None,
}
@pytest.mark.asyncio
@@ -106,6 +109,8 @@ class TestS3KeysUnchangedTrigger:
"last_activity_time": None,
"hook_params": {},
"verify": None,
+ "region_name": None,
+ "botocore_config": None,
"polling_period_seconds": 0,
}
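The trigger tests above check that the new connection settings survive serialization. A small sketch of that round trip, using hypothetical bucket and key values and the parameters added in the triggers/s3.py hunk:

    from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger

    # Hypothetical values; the asserted keys are the ones added to serialize() in this commit.
    trigger = S3KeyTrigger(
        bucket_name="example-bucket",
        bucket_key="incoming/data.csv",
        aws_conn_id="aws_default",
        region_name="us-east-1",
        verify=None,
        botocore_config=None,
    )
    classpath, kwargs = trigger.serialize()
    assert kwargs["region_name"] == "us-east-1"
    assert kwargs["botocore_config"] is None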
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py
index 9539e257aff..78821ed041f 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py
@@ -98,9 +98,8 @@ class TestS3ToGoogleCloudStorageOperator:
assert operator.poll_interval == POLL_INTERVAL
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
- def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
+ def test_execute(self, gcs_mock_hook, s3_mock_hook):
"""Test the execute function when the run is successful."""
operator = S3ToGCSOperator(
@@ -112,9 +111,9 @@ class TestS3ToGoogleCloudStorageOperator:
dest_gcs=GCS_PATH_PREFIX,
google_impersonation_chain=IMPERSONATION_CHAIN,
)
+ operator.hook = mock.MagicMock()
- s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
- s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES
+ operator.hook.list_keys.return_value = MOCK_FILES
uploaded_files = operator.execute(context={})
gcs_mock_hook.return_value.upload.assert_has_calls(
@@ -126,8 +125,8 @@ class TestS3ToGoogleCloudStorageOperator:
any_order=True,
)
- s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
- s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
+ operator.hook.list_keys.assert_called_once()
+ s3_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
gcs_mock_hook.assert_called_once_with(
gcp_conn_id=GCS_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -137,9 +136,8 @@ class TestS3ToGoogleCloudStorageOperator:
assert sorted(MOCK_FILES) == sorted(uploaded_files)
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
- def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook):
+ def test_execute_with_gzip(self, gcs_mock_hook, s3_mock_hook):
"""Test the execute function when the run is successful."""
operator = S3ToGCSOperator(
@@ -152,8 +150,9 @@ class TestS3ToGoogleCloudStorageOperator:
gzip=True,
)
- s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES
- s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES
+ operator.hook = mock.MagicMock()
+
+ operator.hook.list_keys.return_value = MOCK_FILES
operator.execute(context={})
gcs_mock_hook.assert_called_once_with(
@@ -226,13 +225,11 @@ class TestS3ToGoogleCloudStorageOperator:
@pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS)
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
def test_execute_apply_gcs_prefix(
self,
gcs_mock_hook,
- s3_one_mock_hook,
- s3_two_mock_hook,
+ s3_mock_hook,
apply_gcs_prefix,
s3_prefix,
s3_object,
@@ -249,9 +246,8 @@ class TestS3ToGoogleCloudStorageOperator:
google_impersonation_chain=IMPERSONATION_CHAIN,
apply_gcs_prefix=apply_gcs_prefix,
)
-
- s3_one_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]
- s3_two_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object]
+ operator.hook = mock.MagicMock()
+ operator.hook.list_keys.return_value = [s3_prefix + s3_object]
uploaded_files = operator.execute(context={})
gcs_mock_hook.return_value.upload.assert_has_calls(
@@ -261,8 +257,8 @@ class TestS3ToGoogleCloudStorageOperator:
any_order=True,
)
- s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
- s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
+ operator.hook.list_keys.assert_called_once()
+ s3_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None)
gcs_mock_hook.assert_called_once_with(
gcp_conn_id=GCS_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
@@ -306,14 +302,12 @@ class TestS3ToGoogleCloudStorageOperator:
class TestS3ToGoogleCloudStorageOperatorDeferrable:
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.CloudDataTransferServiceHook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook")
- @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
@mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook")
- def test_execute_deferrable(self, mock_gcs_hook, mock_s3_super_hook, mock_s3_hook, mock_transfer_hook):
+ def test_execute_deferrable(self, mock_gcs_hook, mock_s3_hook, mock_transfer_hook):
mock_gcs_hook.return_value.project_id = PROJECT_ID
- mock_list_keys = mock.MagicMock()
- mock_list_keys.return_value = MOCK_FILES
- mock_s3_super_hook.return_value.list_keys = mock_list_keys
+ mock_s3_super_hook = mock.MagicMock()
+ mock_s3_super_hook.list_keys.return_value = MOCK_FILES
mock_s3_hook.conn_config = mock.MagicMock(
aws_access_key_id=AWS_ACCESS_KEY_ID,
aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
@@ -335,11 +329,13 @@ class TestS3ToGoogleCloudStorageOperatorDeferrable:
deferrable=True,
)
+ operator.hook = mock_s3_super_hook
+
with pytest.raises(TaskDeferred) as exception_info:
operator.execute(None)
- mock_s3_super_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=operator.verify)
- mock_list_keys.assert_called_once_with(
+ mock_s3_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=operator.verify)
+ mock_s3_super_hook.list_keys.assert_called_once_with(
bucket_name=S3_BUCKET, prefix=S3_PREFIX, delimiter=S3_DELIMITER, apply_wildcard=False
)
mock_create_transfer_job.assert_called_once()
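The test updates in this commit all follow one pattern: instead of patching the S3Hook class at the operator module path, they overwrite the operator's hook attribute with a MagicMock. A minimal sketch of that pattern, assuming a pytest-style test module and hypothetical bucket values:

    from unittest import mock

    from airflow.providers.amazon.aws.operators.s3 import S3ListOperator


    def test_list_keys_with_mocked_hook():
        operator = S3ListOperator(
            task_id="list-keys", bucket="example-bucket", prefix="TEST", delimiter=".csv"
        )
        # Replace the hook provided by the base class rather than patching S3Hook itself.
        operator.hook = mock.MagicMock()
        operator.hook.list_keys.return_value = ["TEST1.csv"]

        assert operator.execute(None) == ["TEST1.csv"]
        operator.hook.list_keys.assert_called_once()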