This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b8a093e07 Use tmp_path in amazon s3 test (#33705)
9b8a093e07 is described below

commit 9b8a093e0738f8491b8dc698e99e68aa7780a669
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Thu Aug 24 22:07:56 2023 +0000

    Use tmp_path in amazon s3 test (#33705)
---
 tests/providers/amazon/aws/hooks/test_s3.py | 214 +++++++++++++---------------
 1 file changed, 98 insertions(+), 116 deletions(-)

diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 072137dc5f..686ef943ca 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -21,9 +21,7 @@ import gzip as gz
 import inspect
 import os
 import re
-import tempfile
 import unittest
-from pathlib import Path
 from unittest import mock, mock as async_mock
 from unittest.mock import MagicMock, Mock, patch
 
@@ -827,63 +825,52 @@ class TestAwsS3Hook:
             response["Grants"][0]["Permission"] == "FULL_CONTROL"
         )
 
-    def test_load_fileobj(self, s3_bucket):
+    def test_load_fileobj(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket)
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert resource.get()["Body"].read() == b"Content"
-
-    def test_load_fileobj_acl(self, s3_bucket):
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket)
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert resource.get()["Body"].read() == b"Content"
+
+    def test_load_fileobj_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key", RequestPayer="requester"
-            )
-            assert (response["Grants"][1]["Permission"] == "READ") and (
-                response["Grants"][0]["Permission"] == "FULL_CONTROL"
-            )
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")
+        response = boto3.client("s3").get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer="requester")
+        assert (response["Grants"][1]["Permission"] == "READ") and (
+            response["Grants"][0]["Permission"] == "FULL_CONTROL"
+        )
 
-    def test_load_file_gzip(self, s3_bucket):
+    def test_load_file_gzip(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True)
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert gz.decompress(resource.get()["Body"].read()) == b"Content"
-            os.unlink(temp_file.name)
-
-    def test_load_file_acl(self, s3_bucket):
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket, gzip=True)
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert gz.decompress(resource.get()["Body"].read()) == b"Content"
+
+    def test_load_file_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file(temp_file.name, "my_key", s3_bucket, gzip=True, acl_policy="public-read")
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key", RequestPayer="requester"
-            )
-            assert (response["Grants"][1]["Permission"] == "READ") and (
-                response["Grants"][0]["Permission"] == "FULL_CONTROL"
-            )
-            os.unlink(temp_file.name)
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket, gzip=True, acl_policy="public-read")
+        response = boto3.client("s3").get_object_acl(Bucket=s3_bucket, Key="my_key", RequestPayer="requester")
+        assert (response["Grants"][1]["Permission"] == "READ") and (
+            response["Grants"][0]["Permission"] == "FULL_CONTROL"
+        )
 
-    def test_copy_object_acl(self, s3_bucket):
+    def test_copy_object_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
-        with tempfile.NamedTemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket)
-            hook.copy_object("my_key", "my_key2", s3_bucket, s3_bucket)
-            response = boto3.client("s3").get_object_acl(
-                Bucket=s3_bucket, Key="my_key2", RequestPayer="requester"
-            )
-            assert (response["Grants"][0]["Permission"] == "FULL_CONTROL") and (len(response["Grants"]) == 1)
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket)
+        hook.copy_object("my_key", "my_key2", s3_bucket, s3_bucket)
+        response = boto3.client("s3").get_object_acl(
+            Bucket=s3_bucket, Key="my_key2", RequestPayer="requester"
+        )
+        assert (response["Grants"][0]["Permission"] == "FULL_CONTROL") and (len(response["Grants"]) == 1)
 
     @mock_s3
     def test_delete_bucket_if_bucket_exist(self, s3_bucket):
@@ -974,34 +961,33 @@ class TestAwsS3Hook:
         assert isinstance(ctx.value, ValueError)
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
-    def test_download_file(self, mock_temp_file):
-        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
-            mock_temp_file.return_value = temp_file
-            s3_hook = S3Hook(aws_conn_id="s3_test")
-            s3_hook.check_for_key = Mock(return_value=True)
-            s3_obj = Mock()
-            s3_obj.download_fileobj = Mock(return_value=None)
-            s3_hook.get_key = Mock(return_value=s3_obj)
-            key = "test_key"
-            bucket = "test_bucket"
-
-            output_file = s3_hook.download_file(key=key, bucket_name=bucket)
-
-            s3_hook.get_key.assert_called_once_with(key, bucket)
-            s3_obj.download_fileobj.assert_called_once_with(
-                temp_file,
-                Config=s3_hook.transfer_config,
-                ExtraArgs=s3_hook.extra_args,
-            )
+    def test_download_file(self, mock_temp_file, tmp_path):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        mock_temp_file.return_value = path
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        key = "test_key"
+        bucket = "test_bucket"
+
+        output_file = s3_hook.download_file(key=key, bucket_name=bucket)
 
-            assert temp_file.name == output_file
+        s3_hook.get_key.assert_called_once_with(key, bucket)
+        s3_obj.download_fileobj.assert_called_once_with(
+            path,
+            Config=s3_hook.transfer_config,
+            ExtraArgs=s3_hook.extra_args,
+        )
+
+        assert path.name == output_file
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
-    def test_download_file_with_preserve_name(self, mock_open):
-        file_name = "test.log"
+    def test_download_file_with_preserve_name(self, mock_open, tmp_path):
+        path = tmp_path / "test.log"
         bucket = "test_bucket"
-        key = f"test_key/{file_name}"
-        local_folder = "/tmp"
+        key = f"test_key/{path.name}"
 
         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
@@ -1012,19 +998,18 @@ class TestAwsS3Hook:
         s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=local_folder,
+            local_path=os.fspath(path.parent),
             preserve_file_name=True,
             use_autogenerated_subdir=False,
         )
 
-        mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")
+        mock_open.assert_called_once_with(path, "wb")
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
-    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
-        file_name = "test.log"
+    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open, tmp_path):
+        path = tmp_path / "test.log"
         bucket = "test_bucket"
-        key = f"test_key/{file_name}"
-        local_folder = "/tmp"
+        key = f"test_key/{path.name}"
 
         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
@@ -1035,33 +1020,32 @@ class TestAwsS3Hook:
         result_file = s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=local_folder,
+            local_path=os.fspath(path.parent),
             preserve_file_name=True,
             use_autogenerated_subdir=True,
         )
 
         assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_")
 
-    def test_download_file_with_preserve_name_file_already_exists(self):
-        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
-            file_name = file.name.rsplit("/", 1)[-1]
-            bucket = "test_bucket"
-            key = f"test_key/{file_name}"
-            local_folder = "/tmp"
-            s3_hook = S3Hook(aws_conn_id="s3_test")
-            s3_hook.check_for_key = Mock(return_value=True)
-            s3_obj = Mock()
-            s3_obj.key = f"s3://{bucket}/{key}"
-            s3_obj.download_fileobj = Mock(return_value=None)
-            s3_hook.get_key = Mock(return_value=s3_obj)
-            with pytest.raises(FileExistsError):
-                s3_hook.download_file(
-                    key=key,
-                    bucket_name=bucket,
-                    local_path=local_folder,
-                    preserve_file_name=True,
-                    use_autogenerated_subdir=False,
-                )
+    def test_download_file_with_preserve_name_file_already_exists(self, tmp_path):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        path.write_text("")
+        bucket = "test_bucket"
+        key = f"test_key/{path.name}"
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        with pytest.raises(FileExistsError):
+            s3_hook.download_file(
+                key=key,
+                bucket_name=bucket,
+                local_path=os.fspath(path.parent),
+                preserve_file_name=True,
+                use_autogenerated_subdir=False,
+            )
 
     def test_generate_presigned_url(self, s3_bucket):
         hook = S3Hook()
@@ -1078,22 +1062,20 @@ class TestAwsS3Hook:
         with pytest.raises(TypeError, match="extra_args expected dict, got .*"):
             S3Hook(extra_args=1)
 
-    def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket):
+    def test_should_throw_error_if_extra_args_contains_unknown_arg(self, s3_bucket, tmp_path):
         hook = S3Hook(extra_args={"unknown_s3_args": "value"})
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            with pytest.raises(ValueError):
-                hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        with pytest.raises(ValueError):
+            hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")
 
-    def test_should_pass_extra_args(self, s3_bucket):
+    def test_should_pass_extra_args(self, s3_bucket, tmp_path):
         hook = S3Hook(extra_args={"ContentLanguage": "value"})
-        with tempfile.TemporaryFile() as temp_file:
-            temp_file.write(b"Content")
-            temp_file.seek(0)
-            hook.load_file_obj(temp_file, "my_key", s3_bucket, acl_policy="public-read")
-            resource = boto3.resource("s3").Object(s3_bucket, "my_key")
-            assert resource.get()["ContentLanguage"] == "value"
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file_obj(path.open("rb"), "my_key", s3_bucket, acl_policy="public-read")
+        resource = boto3.resource("s3").Object(s3_bucket, "my_key")
+        assert resource.get()["ContentLanguage"] == "value"
 
     def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
         original = {

Reply via email to