This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 16e0830a5d Make Http provider sync and async test consistent (#32715)
16e0830a5d is described below

commit 16e0830a5dfe42b9ab0bbca7f8023bf050bbced0
Author: Pankaj Singh <[email protected]>
AuthorDate: Sun Jul 30 04:14:59 2023 +0530

    Make Http provider sync and async test consistent (#32715)
    
    * Http provider sync and async test consistent
    
    The http provider async tests are function based but the sync tests are class based.
    This PR makes them consistent, i.e. class based.
    Fixed some typos and removed redundant code.
    
    * Fix static check
    
    * Apply review suggestions
---
 tests/providers/http/hooks/test_http.py | 209 ++++++++++++++------------------
 1 file changed, 92 insertions(+), 117 deletions(-)

diff --git a/tests/providers/http/hooks/test_http.py 
b/tests/providers/http/hooks/test_http.py
index e702dded9c..8970c8eda8 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -37,8 +37,17 @@ from airflow.models import Connection
 from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook
 
 
+@pytest.fixture
+def aioresponse():
+    """
+    Creates mock async API response.
+    """
+    with aioresponses() as async_response:
+        yield async_response
+
+
 def get_airflow_connection(unused_conn_id=None):
-    return Connection(conn_id="http_default", conn_type="http", 
host="test:8080/", extra='{"bareer": "test"}')
+    return Connection(conn_id="http_default", conn_type="http", 
host="test:8080/", extra='{"bearer": "test"}')
 
 
 def get_airflow_connection_with_port(unused_conn_id=None):
@@ -110,7 +119,7 @@ class TestHttpHook:
             expected_conn = get_airflow_connection()
             conn = self.get_hook.get_conn()
             assert dict(conn.headers, **json.loads(expected_conn.extra)) == 
conn.headers
-            assert conn.headers.get("bareer") == "test"
+            assert conn.headers.get("bearer") == "test"
 
     @mock.patch("requests.Request")
     def test_hook_with_method_in_lowercase(self, mock_requests):
@@ -127,17 +136,17 @@ class TestHttpHook:
             mock_requests.assert_called_once_with(mock.ANY, mock.ANY, 
headers=mock.ANY, params=data)
 
     def test_hook_uses_provided_header(self):
-        conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
-        assert conn.headers.get("bareer") == "newT0k3n"
+        conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"})
+        assert conn.headers.get("bearer") == "newT0k3n"
 
     def test_hook_has_no_header_from_extra(self):
         conn = self.get_hook.get_conn()
-        assert conn.headers.get("bareer") is None
+        assert conn.headers.get("bearer") is None
 
     def test_hooks_header_from_extra_is_overridden(self):
         with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-            conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
-            assert conn.headers.get("bareer") == "newT0k3n"
+            conn = self.get_hook.get_conn(headers={"bearer": "newT0k3n"})
+            assert conn.headers.get("bearer") == "newT0k3n"
 
     def test_post_request(self, requests_mock):
         requests_mock.post(
@@ -213,7 +222,7 @@ class TestHttpHook:
             with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
                 prepared_request = self.get_hook.run("v1/test", 
headers={"some_other_header": "test"})
                 actual = dict(prepared_request.headers)
-                assert actual.get("bareer") == "test"
+                assert actual.get("bearer") == "test"
                 assert actual.get("some_other_header") == "test"
 
     @mock.patch("airflow.providers.http.hooks.http.HttpHook.get_connection")
@@ -395,8 +404,6 @@ class TestHttpHook:
             hook.get_conn()
             auth.assert_not_called()
 
-
-class TestKeepAlive:
     def test_keep_alive_enabled(self):
         with mock.patch(
             "airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection_with_port
@@ -432,125 +439,93 @@ class TestKeepAlive:
             http_send.assert_called()
 
 
-send_email_test = mock.Mock()
-
-
-@pytest.fixture
-def aioresponse():
-    """
-    Creates an mock async API response.
-    This comes from a mock library specific to the aiohttp package:
-    https://github.com/pnuckowski/aioresponses
-
-    """
-    with aioresponses() as async_response:
-        yield async_response
-
-
-@pytest.mark.asyncio
-async def test_do_api_call_async_non_retryable_error(aioresponse):
-    """Test api call asynchronously with non retryable error."""
-    hook = HttpAsyncHook(method="GET")
-    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=400)
-
-    with pytest.raises(AirflowException) as exc, mock.patch.dict(
-        "os.environ",
-        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
-    ):
-        await hook.run(endpoint="non_existent_endpoint")
-
-    assert str(exc.value) == "400:Bad Request"
-
-
-@pytest.mark.asyncio
-async def test_do_api_call_async_retryable_error(caplog, aioresponse):
-    """Test api call asynchronously with retryable error."""
-    caplog.set_level(logging.WARNING, 
logger="airflow.providers.http.hooks.http")
-    hook = HttpAsyncHook(method="GET")
-    aioresponse.get("http://httpbin.org/non_existent_endpoint", status=500, 
repeat=True)
+class TestHttpAsyncHook:
+    @pytest.mark.asyncio
+    async def test_do_api_call_async_non_retryable_error(self, aioresponse):
+        """Test api call asynchronously with non retryable error."""
+        hook = HttpAsyncHook(method="GET")
+        aioresponse.get("http://httpbin.org/non_existent_endpoint", status=400)
 
-    with pytest.raises(AirflowException) as exc, mock.patch.dict(
-        "os.environ",
-        AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
-    ):
-        await hook.run(endpoint="non_existent_endpoint")
-
-    assert str(exc.value) == "500:Internal Server Error"
-    assert "[Try 3 of 3] Request to http://httpbin.org/non_existent_endpoint 
failed" in caplog.text
-
-
-@pytest.mark.asyncio
-async def test_do_api_call_async_unknown_method():
-    """Test api call asynchronously for unknown method."""
-    hook = HttpAsyncHook(method="NOPE")
-    json = {
-        "existing_cluster_id": "xxxx-xxxxxx-xxxxxx",
-    }
-
-    with pytest.raises(AirflowException) as exc:
-        await hook.run(endpoint="non_existent_endpoint", data=json)
-
-    assert str(exc.value) == "Unexpected HTTP Method: NOPE"
+        with pytest.raises(AirflowException, match="400:Bad Request"), 
mock.patch.dict(
+            "os.environ",
+            AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+        ):
+            await hook.run(endpoint="non_existent_endpoint")
+
+    @pytest.mark.asyncio
+    async def test_do_api_call_async_retryable_error(self, caplog, 
aioresponse):
+        """Test api call asynchronously with retryable error."""
+        caplog.set_level(logging.WARNING, 
logger="airflow.providers.http.hooks.http")
+        hook = HttpAsyncHook(method="GET")
+        aioresponse.get("http://httpbin.org/non_existent_endpoint", 
status=500, repeat=True)
+
+        with pytest.raises(AirflowException, match="500:Internal Server 
Error"), mock.patch.dict(
+            "os.environ",
+            AIRFLOW_CONN_HTTP_DEFAULT="http://httpbin.org/",
+        ):
+            await hook.run(endpoint="non_existent_endpoint")
 
+        assert "[Try 3 of 3] Request to 
http://httpbin.org/non_existent_endpoint failed" in caplog.text
 
-@pytest.mark.asyncio
-async def test_async_post_request(aioresponse):
-    """Test api call asynchronously for POST request."""
-    hook = HttpAsyncHook()
+    @pytest.mark.asyncio
+    async def test_do_api_call_async_unknown_method(self):
+        """Test api call asynchronously for unknown http method."""
+        hook = HttpAsyncHook(method="NOPE")
+        json = {"existing_cluster_id": "xxxx-xxxxxx-xxxxxx"}
 
-    aioresponse.post(
-        "http://test:8080/v1/test",
-        status=200,
-        payload='{"status":{"status": 200}}',
-        reason="OK",
-    )
+        with pytest.raises(AirflowException, match="Unexpected HTTP Method: 
NOPE"):
+            await hook.run(endpoint="non_existent_endpoint", data=json)
 
-    with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-        resp = await hook.run("v1/test")
-        assert resp.status == 200
+    @pytest.mark.asyncio
+    async def test_async_post_request(self, aioresponse):
+        """Test api call asynchronously for POST request."""
+        hook = HttpAsyncHook()
 
+        aioresponse.post(
+            "http://test:8080/v1/test",
+            status=200,
+            payload='{"status":{"status": 200}}',
+            reason="OK",
+        )
 
-@pytest.mark.asyncio
-async def test_async_post_request_with_error_code(aioresponse):
-    """Test api call asynchronously for POST request with error."""
-    hook = HttpAsyncHook()
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+            resp = await hook.run("v1/test")
+            assert resp.status == 200
 
-    aioresponse.post(
-        "http://test:8080/v1/test",
-        status=418,
-        payload='{"status":{"status": 418}}',
-        reason="I am teapot",
-    )
+    @pytest.mark.asyncio
+    async def test_async_post_request_with_error_code(self, aioresponse):
+        """Test api call asynchronously for POST request with error."""
+        hook = HttpAsyncHook()
 
-    with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
-        with pytest.raises(AirflowException):
-            await hook.run("v1/test")
+        aioresponse.post(
+            "http://test:8080/v1/test",
+            status=418,
+            payload='{"status":{"status": 418}}',
+            reason="I am teapot",
+        )
 
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+            with pytest.raises(AirflowException):
+                await hook.run("v1/test")
 
-@pytest.mark.asyncio
-async def test_async_request_uses_connection_extra(aioresponse):
-    """Test api call asynchronously with a connection that has extra field."""
+    @pytest.mark.asyncio
+    async def test_async_request_uses_connection_extra(self, aioresponse):
+        """Test api call asynchronously with a connection that has extra 
field."""
 
-    connection_extra = {"bareer": "test"}
-    connection_id = "http_default"
+        connection_extra = {"bearer": "test"}
 
-    def get_airflow_connection_with_extra(unused_conn_id=None):
-        return Connection(
-            conn_id=connection_id, conn_type="http", host="test:8080/", 
extra=json.dumps(connection_extra)
+        aioresponse.post(
+            "http://test:8080/v1/test",
+            status=200,
+            payload='{"status":{"status": 200}}',
+            reason="OK",
         )
 
-    aioresponse.post(
-    "http://test:8080/v1/test",
-        status=200,
-        payload='{"status":{"status": 200}}',
-        reason="OK",
-    )
-
-    with mock.patch(
-        "airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection_with_extra
-    ):
-        hook = HttpAsyncHook()
-        with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
-            await hook.run("v1/test")
-            headers = mocked_function.call_args.kwargs.get("headers")
-            assert all(key in headers and headers[key] == value for key, value 
in connection_extra.items())
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", 
side_effect=get_airflow_connection):
+            hook = HttpAsyncHook()
+            with mock.patch("aiohttp.ClientSession.post", 
new_callable=mock.AsyncMock) as mocked_function:
+                await hook.run("v1/test")
+                headers = mocked_function.call_args.kwargs.get("headers")
+                assert all(
+                    key in headers and headers[key] == value for key, value in 
connection_extra.items()
+                )

Reply via email to