This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 19316daad8 migrated some part of test always to pytest (#29375)
19316daad8 is described below

commit 19316daad84d6d7765185e5f0f937894fc0162e8
Author: Abhishek-kumar-samsung <[email protected]>
AuthorDate: Mon Feb 13 01:39:28 2023 +0530

    migrated some part of test always to pytest (#29375)
---
 tests/always/test_connection.py        | 46 ++++++++++++++++++----------------
 tests/always/test_project_structure.py | 12 ++++-----
 2 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/tests/always/test_connection.py b/tests/always/test_connection.py
index 9ed2986f09..70d57df099 100644
--- a/tests/always/test_connection.py
+++ b/tests/always/test_connection.py
@@ -20,14 +20,12 @@ from __future__ import annotations
 import json
 import os
 import re
-import unittest
 from collections import namedtuple
 from unittest import mock
 
 import pytest
 import sqlalchemy
 from cryptography.fernet import Fernet
-from parameterized import parameterized
 
 from airflow import AirflowException
 from airflow.hooks.base import BaseHook
@@ -61,16 +59,15 @@ class UriTestCaseConfig:
         return f"{func.__name__}_{num}_{param.args[0].description.replace(' ', '_')}"
 
 
-class TestConnection(unittest.TestCase):
-    def setUp(self):
+class TestConnection:
+    def setup_method(self):
         crypto._fernet = None
-        patcher = mock.patch("airflow.models.connection.mask_secret", autospec=True)
-        self.mask_secret = patcher.start()
+        self.patcher = mock.patch("airflow.models.connection.mask_secret", autospec=True)
+        self.mask_secret = self.patcher.start()
 
-        self.addCleanup(patcher.stop)
-
-    def tearDown(self):
+    def teardown_method(self):
         crypto._fernet = None
+        self.patcher.stop()
 
     @conf_vars({("core", "fernet_key"): ""})
     def test_connection_extra_no_encryption(self):
@@ -350,7 +347,7 @@ class TestConnection(unittest.TestCase):
         ),
     ]
 
-    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
+    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
     def test_connection_from_uri(self, test_config: UriTestCaseConfig):
 
         connection = Connection(uri=test_config.test_uri)
@@ -372,7 +369,7 @@ class TestConnection(unittest.TestCase):
 
         self.mask_secret.assert_has_calls(expected_calls)
 
-    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
+    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
     def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
         """
         This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
@@ -393,7 +390,7 @@ class TestConnection(unittest.TestCase):
         assert connection.schema == new_conn.schema
         assert connection.extra_dejson == new_conn.extra_dejson
 
-    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
+    @pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
     def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
         """
         This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a
@@ -421,7 +418,8 @@ class TestConnection(unittest.TestCase):
             else:
                 assert actual_val == expected_val
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "uri,uri_parts",
         [
             (
                 "http://:password@host:80/database",
@@ -486,7 +484,7 @@ class TestConnection(unittest.TestCase):
                     schema="",
                 ),
             ),
-        ]
+        ],
     )
     def test_connection_from_with_auth_info(self, uri, uri_parts):
         connection = Connection(uri=uri)
@@ -498,46 +496,50 @@ class TestConnection(unittest.TestCase):
         assert connection.port == uri_parts.port
         assert connection.schema == uri_parts.schema
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "extra,expected",
         [
             ('{"extra": null}', None),
             ('{"extra": "hi"}', "hi"),
             ('{"extra": {"yo": "hi"}}', '{"yo": "hi"}'),
             ('{"extra": "{\\"yo\\": \\"hi\\"}"}', '{"yo": "hi"}'),
-        ]
+        ],
     )
     def test_from_json_extra(self, extra, expected):
         """json serialization should support extra stored as object _or_ as string"""
         assert Connection.from_json(extra).extra == expected
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "val,expected",
         [
             ('{"conn_type": "abc-abc"}', "abc_abc"),
             ('{"conn_type": "abc_abc"}', "abc_abc"),
             ('{"conn_type": "postgresql"}', "postgres"),
-        ]
+        ],
     )
     def test_from_json_conn_type(self, val, expected):
         """two conn_type normalizations are applied: replace - with _ and postgresql with postgres"""
         assert Connection.from_json(val).conn_type == expected
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "val,expected",
         [
             ('{"port": 1}', 1),
             ('{"port": "1"}', 1),
             ('{"port": null}', None),
-        ]
+        ],
     )
     def test_from_json_port(self, val, expected):
         """two conn_type normalizations are applied: replace - with _ and postgresql with postgres"""
         assert Connection.from_json(val).port == expected
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "val,expected",
         [
             ('pass :/!@#$%^&*(){}"', 'pass :/!@#$%^&*(){}"'),  # these are the same
             (None, None),
             ("", None),  # this is a consequence of the password getter
-        ]
+        ],
     )
     def test_from_json_special_characters(self, val, expected):
         """two conn_type normalizations are applied: replace - with _ and postgresql with postgres"""
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index f39d8b95fd..479e68fcbc 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -21,7 +21,6 @@ import glob
 import itertools
 import mmap
 import os
-import unittest
 
 import pytest
 
@@ -30,7 +29,7 @@ ROOT_FOLDER = os.path.realpath(
 )
 
 
-class TestProjectStructure(unittest.TestCase):
+class TestProjectStructure:
     def test_reference_to_providers_from_core(self):
         for filename in glob.glob(f"{ROOT_FOLDER}/example_dags/**/*.py", recursive=True):
             self.assert_file_not_contains(filename, "providers")
@@ -47,12 +46,12 @@ class TestProjectStructure(unittest.TestCase):
     def assert_file_not_contains(self, filename: str, pattern: str):
         with open(filename, "rb", 0) as file, mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ) as content:
             if content.find(bytes(pattern, "utf-8")) != -1:
-                self.fail(f"File {filename} not contains pattern - {pattern}")
+                pytest.fail(f"File {filename} not contains pattern - {pattern}")
 
     def assert_file_contains(self, filename: str, pattern: str):
         with open(filename, "rb", 0) as file, mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ) as content:
             if content.find(bytes(pattern, "utf-8")) == -1:
-                self.fail(f"File {filename} contains illegal pattern - {pattern}")
+                pytest.fail(f"File {filename} contains illegal pattern - {pattern}")
 
     def test_providers_modules_should_have_tests(self):
         """
@@ -90,8 +89,7 @@ class TestProjectStructure(unittest.TestCase):
 
         missing_tests_files = expected_test_files - expected_test_files.intersection(current_test_files)
 
-        with self.subTest("Detect missing tests in providers module"):
-            assert set() == missing_tests_files
+        assert set() == missing_tests_files, "Detect missing tests in providers module"
 
 
 def get_imports_from_file(filepath: str):
@@ -419,7 +417,7 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest):
     PROVIDER = "docker"
 
 
-class TestOperatorsHooks(unittest.TestCase):
+class TestOperatorsHooks:
     def test_no_illegal_suffixes(self):
         illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
         files = itertools.chain(

Reply via email to