This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 9edfcb7 Support extra_args in S3Hook and GCSToS3Operator (#11001)
9edfcb7 is described below
commit 9edfcb7ac46917836ec956264da8876e58d92392
Author: Shekhar Singh <[email protected]>
AuthorDate: Sat Sep 19 06:33:21 2020 +0530
Support extra_args in S3Hook and GCSToS3Operator (#11001)
---
airflow/providers/amazon/aws/hooks/s3.py | 13 ++++--
.../providers/amazon/aws/transfers/gcs_to_s3.py | 8 +++-
tests/providers/amazon/aws/hooks/test_s3.py | 21 +++++++++
.../amazon/aws/transfers/test_gcs_to_s3.py | 51 ++++++++++++++++++++++
4 files changed, 88 insertions(+), 5 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py
b/airflow/providers/amazon/aws/hooks/s3.py
index 21ed054..30adcc3 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -109,6 +109,14 @@ class S3Hook(AwsBaseHook):
def __init__(self, *args, **kwargs) -> None:
kwargs['client_type'] = 's3'
+
+ self.extra_args = {}
+ if 'extra_args' in kwargs:
+ self.extra_args = kwargs['extra_args']
+ if not isinstance(self.extra_args, dict):
+ raise ValueError("extra_args '%r' must be of type %s" % (self.extra_args, dict))
+ del kwargs['extra_args']
+
super().__init__(*args, **kwargs)
@staticmethod
@@ -485,11 +493,10 @@ class S3Hook(AwsBaseHook):
if not replace and self.check_for_key(key, bucket_name):
raise ValueError("The key {key} already exists.".format(key=key))
- extra_args = {}
+ extra_args = self.extra_args
if encrypt:
extra_args['ServerSideEncryption'] = "AES256"
if gzip:
- filename_gz = ''
with open(filename, 'rb') as f_in:
filename_gz = f_in.name + '.gz'
with gz.open(filename_gz, 'wb') as f_out:
@@ -625,7 +632,7 @@ class S3Hook(AwsBaseHook):
if not replace and self.check_for_key(key, bucket_name):
raise ValueError("The key {key} already exists.".format(key=key))
- extra_args = {}
+ extra_args = self.extra_args
if encrypt:
extra_args['ServerSideEncryption'] = "AES256"
if acl_policy:
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 1a13c9d..9b00e33 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -19,7 +19,7 @@
This module contains Google Cloud Storage to S3 operator.
"""
import warnings
-from typing import Iterable, Optional, Sequence, Union
+from typing import Iterable, Optional, Sequence, Union, Dict
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -108,6 +108,7 @@ class GCSToS3Operator(BaseOperator):
dest_verify=None,
replace=False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ dest_s3_extra_args: Optional[Dict] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -131,6 +132,7 @@ class GCSToS3Operator(BaseOperator):
self.dest_verify = dest_verify
self.replace = replace
self.google_impersonation_chain = google_impersonation_chain
+ self.dest_s3_extra_args = dest_s3_extra_args or {}
def execute(self, context):
# list all files in an Google Cloud Storage bucket
@@ -149,7 +151,9 @@ class GCSToS3Operator(BaseOperator):
files = hook.list(bucket_name=self.bucket, prefix=self.prefix,
delimiter=self.delimiter)
- s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
+ s3_hook = S3Hook(
+ aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify, extra_args=self.dest_s3_extra_args
+ )
if not self.replace:
# if we are not replacing -> list all files in the S3 bucket
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py
b/tests/providers/amazon/aws/hooks/test_s3.py
index 3aa58c9..f83d93b 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -425,3 +425,24 @@ class TestAwsS3Hook:
params = {x[0]: x[1] for x in [x.split("=") for x in url[0:].split("&")]}
assert {"AWSAccessKeyId", "Signature", "Expires"}.issubset(set(params.keys()))
+
+ def test_should_throw_error_if_extra_args_is_not_dict(self):
+ with pytest.raises(ValueError):
+ S3Hook(extra_args=1)
+
+ def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket):
+ hook = S3Hook(extra_args={"unknown_s3_args": "value"})
+ with tempfile.TemporaryFile() as temp_file:
+ temp_file.write(b"Content")
+ temp_file.seek(0)
+ with pytest.raises(ValueError):
+ hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read')
+
+ def test_should_pass_extra_args(self, s3_bucket):
+ hook = S3Hook(extra_args={"ContentLanguage": "value"})
+ with tempfile.TemporaryFile() as temp_file:
+ temp_file.write(b"Content")
+ temp_file.seek(0)
+ hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read')
+ resource = boto3.resource('s3').Object(s3_bucket, 'my_key')  # pylint: disable=no-member
+ assert resource.get()['ContentLanguage'] == "value"
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index 83b1239..fd84874 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -189,3 +189,54 @@ class TestGCSToS3Operator(unittest.TestCase):
uploaded_files = operator.execute(None)
self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+
+ @mock_s3
+ @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
+ def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hook, mock_hook, mock_hook2):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+ mock_hook.return_value.download.return_value = b"testing"
+ mock_hook2.return_value.list.return_value = MOCK_FILES
+ s3_mock_hook.return_value = mock.Mock()
+ s3_mock_hook.parse_s3_url.return_value = mock.Mock()
+
+ operator = GCSToS3Operator(
+ task_id=TASK_ID,
+ bucket=GCS_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ dest_aws_conn_id="aws_default",
+ dest_s3_key=S3_BUCKET,
+ replace=True,
+ )
+ operator.execute(None)
+ s3_mock_hook.assert_called_once_with(aws_conn_id='aws_default', extra_args={}, verify=None)
+
+ @mock_s3
+ @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
+ def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, mock_hook, mock_hook2):
+ mock_hook.return_value.list.return_value = MOCK_FILES
+ mock_hook.return_value.download.return_value = b"testing"
+ mock_hook2.return_value.list.return_value = MOCK_FILES
+ s3_mock_hook.return_value = mock.Mock()
+ s3_mock_hook.parse_s3_url.return_value = mock.Mock()
+
+ operator = GCSToS3Operator(
+ task_id=TASK_ID,
+ bucket=GCS_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ dest_aws_conn_id="aws_default",
+ dest_s3_key=S3_BUCKET,
+ replace=True,
+ dest_s3_extra_args={
+ "ContentLanguage": "value",
+ },
+ )
+ operator.execute(None)
+ s3_mock_hook.assert_called_once_with(
+ aws_conn_id='aws_default', extra_args={'ContentLanguage': 'value'}, verify=None
+ )