This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 9edfcb7  Support extra_args in S3Hook and GCSToS3Operator (#11001)
9edfcb7 is described below

commit 9edfcb7ac46917836ec956264da8876e58d92392
Author: Shekhar Singh <[email protected]>
AuthorDate: Sat Sep 19 06:33:21 2020 +0530

    Support extra_args in S3Hook and GCSToS3Operator (#11001)
---
 airflow/providers/amazon/aws/hooks/s3.py           | 13 ++++--
 .../providers/amazon/aws/transfers/gcs_to_s3.py    |  8 +++-
 tests/providers/amazon/aws/hooks/test_s3.py        | 21 +++++++++
 .../amazon/aws/transfers/test_gcs_to_s3.py         | 51 ++++++++++++++++++++++
 4 files changed, 88 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 21ed054..30adcc3 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -109,6 +109,14 @@ class S3Hook(AwsBaseHook):
 
     def __init__(self, *args, **kwargs) -> None:
         kwargs['client_type'] = 's3'
+
+        self.extra_args = {}
+        if 'extra_args' in kwargs:
+            self.extra_args = kwargs['extra_args']
+            if not isinstance(self.extra_args, dict):
+                raise ValueError("extra_args '%r' must be of type %s" % (self.extra_args, dict))
+            del kwargs['extra_args']
+
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -485,11 +493,10 @@ class S3Hook(AwsBaseHook):
         if not replace and self.check_for_key(key, bucket_name):
             raise ValueError("The key {key} already exists.".format(key=key))
 
-        extra_args = {}
+        extra_args = self.extra_args
         if encrypt:
             extra_args['ServerSideEncryption'] = "AES256"
         if gzip:
-            filename_gz = ''
             with open(filename, 'rb') as f_in:
                 filename_gz = f_in.name + '.gz'
                 with gz.open(filename_gz, 'wb') as f_out:
@@ -625,7 +632,7 @@ class S3Hook(AwsBaseHook):
         if not replace and self.check_for_key(key, bucket_name):
             raise ValueError("The key {key} already exists.".format(key=key))
 
-        extra_args = {}
+        extra_args = self.extra_args
         if encrypt:
             extra_args['ServerSideEncryption'] = "AES256"
         if acl_policy:
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 1a13c9d..9b00e33 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -19,7 +19,7 @@
 This module contains Google Cloud Storage to S3 operator.
 """
 import warnings
-from typing import Iterable, Optional, Sequence, Union
+from typing import Iterable, Optional, Sequence, Union, Dict
 
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -108,6 +108,7 @@ class GCSToS3Operator(BaseOperator):
         dest_verify=None,
         replace=False,
         google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+        dest_s3_extra_args: Optional[Dict] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -131,6 +132,7 @@ class GCSToS3Operator(BaseOperator):
         self.dest_verify = dest_verify
         self.replace = replace
         self.google_impersonation_chain = google_impersonation_chain
+        self.dest_s3_extra_args = dest_s3_extra_args or {}
 
     def execute(self, context):
         # list all files in an Google Cloud Storage bucket
@@ -149,7 +151,9 @@ class GCSToS3Operator(BaseOperator):
 
         files = hook.list(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter)
 
-        s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)
+        s3_hook = S3Hook(
+            aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify, extra_args=self.dest_s3_extra_args
+        )
 
         if not self.replace:
             # if we are not replacing -> list all files in the S3 bucket
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 3aa58c9..f83d93b 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -425,3 +425,24 @@ class TestAwsS3Hook:
         params = {x[0]: x[1] for x in [x.split("=") for x in url[0:].split("&")]}
 
         assert {"AWSAccessKeyId", "Signature", "Expires"}.issubset(set(params.keys()))
+
+    def test_should_throw_error_if_extra_args_is_not_dict(self):
+        with pytest.raises(ValueError):
+            S3Hook(extra_args=1)
+
+    def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket):
+        hook = S3Hook(extra_args={"unknown_s3_args": "value"})
+        with tempfile.TemporaryFile() as temp_file:
+            temp_file.write(b"Content")
+            temp_file.seek(0)
+            with pytest.raises(ValueError):
+                hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read')
+
+    def test_should_pass_extra_args(self, s3_bucket):
+        hook = S3Hook(extra_args={"ContentLanguage": "value"})
+        with tempfile.TemporaryFile() as temp_file:
+            temp_file.write(b"Content")
+            temp_file.seek(0)
+            hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy='public-read')
+            resource = boto3.resource('s3').Object(s3_bucket, 'my_key')  # pylint: disable=no-member
+            assert resource.get()['ContentLanguage'] == "value"
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index 83b1239..fd84874 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -189,3 +189,54 @@ class TestGCSToS3Operator(unittest.TestCase):
         uploaded_files = operator.execute(None)
         self.assertEqual(sorted(MOCK_FILES), sorted(uploaded_files))
         self.assertEqual(sorted(MOCK_FILES), sorted(hook.list_keys('bucket', delimiter='/')))
+
+    @mock_s3
+    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
+    @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
+    @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
+    def test_execute_should_handle_with_default_dest_s3_extra_args(self, s3_mock_hook, mock_hook, mock_hook2):
+        mock_hook.return_value.list.return_value = MOCK_FILES
+        mock_hook.return_value.download.return_value = b"testing"
+        mock_hook2.return_value.list.return_value = MOCK_FILES
+        s3_mock_hook.return_value = mock.Mock()
+        s3_mock_hook.parse_s3_url.return_value = mock.Mock()
+
+        operator = GCSToS3Operator(
+            task_id=TASK_ID,
+            bucket=GCS_BUCKET,
+            prefix=PREFIX,
+            delimiter=DELIMITER,
+            dest_aws_conn_id="aws_default",
+            dest_s3_key=S3_BUCKET,
+            replace=True,
+        )
+        operator.execute(None)
+        s3_mock_hook.assert_called_once_with(aws_conn_id='aws_default', extra_args={}, verify=None)
+
+    @mock_s3
+    @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
+    @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
+    @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.S3Hook')
+    def test_execute_should_pass_dest_s3_extra_args_to_s3_hook(self, s3_mock_hook, mock_hook, mock_hook2):
+        mock_hook.return_value.list.return_value = MOCK_FILES
+        mock_hook.return_value.download.return_value = b"testing"
+        mock_hook2.return_value.list.return_value = MOCK_FILES
+        s3_mock_hook.return_value = mock.Mock()
+        s3_mock_hook.parse_s3_url.return_value = mock.Mock()
+
+        operator = GCSToS3Operator(
+            task_id=TASK_ID,
+            bucket=GCS_BUCKET,
+            prefix=PREFIX,
+            delimiter=DELIMITER,
+            dest_aws_conn_id="aws_default",
+            dest_s3_key=S3_BUCKET,
+            replace=True,
+            dest_s3_extra_args={
+                "ContentLanguage": "value",
+            },
+        )
+        operator.execute(None)
+        s3_mock_hook.assert_called_once_with(
+            aws_conn_id='aws_default', extra_args={'ContentLanguage': 'value'}, verify=None
+        )

Reply via email to