This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f77152b10 Refactor: use tmp_path in mysql test (#33544)
3f77152b10 is described below

commit 3f77152b104424a6fd626b4673e0984626fcd628
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Sun Aug 20 21:53:41 2023 +0000

    Refactor: use tmp_path in mysql test (#33544)
---
 tests/providers/mysql/hooks/test_mysql.py     | 41 ++++++++++++---------------
 tests/providers/mysql/operators/test_mysql.py | 17 ++++-------
 2 files changed, 24 insertions(+), 34 deletions(-)

diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index 2b6c81df5d..b4de3ce20f 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -361,31 +361,26 @@ class TestMySql:
             "AIRFLOW_CONN_AIRFLOW_DB": "mysql://root@mysql/airflow?charset=utf8mb4",
         },
     )
-    def test_mysql_hook_test_bulk_load(self, client):
+    def test_mysql_hook_test_bulk_load(self, client, tmp_path):
         with MySqlContext(client):
             records = ("foo", "bar", "baz")
-
-            import tempfile
-
-            with tempfile.NamedTemporaryFile() as f:
-                f.write("\n".join(records).encode("utf8"))
-                f.flush()
-
-                hook = MySqlHook("airflow_db", local_infile=True)
-                with closing(hook.get_conn()) as conn:
-                    with closing(conn.cursor()) as cursor:
-                        cursor.execute(
-                            """
-                            CREATE TABLE IF NOT EXISTS test_airflow (
-                                dummy VARCHAR(50)
-                            )
-                        """
-                        )
-                        cursor.execute("TRUNCATE TABLE test_airflow")
-                        hook.bulk_load("test_airflow", f.name)
-                        cursor.execute("SELECT dummy FROM test_airflow")
-                        results = tuple(result[0] for result in cursor.fetchall())
-                        assert sorted(results) == sorted(records)
+            path = tmp_path / "testfile"
+            path.write_text("\n".join(records))
+
+            hook = MySqlHook("airflow_db", local_infile=True)
+            with closing(hook.get_conn()) as conn, closing(conn.cursor()) as cursor:
+                cursor.execute(
+                    """
+                    CREATE TABLE IF NOT EXISTS test_airflow (
+                        dummy VARCHAR(50)
+                    )
+                """
+                )
+                cursor.execute("TRUNCATE TABLE test_airflow")
+                hook.bulk_load("test_airflow", os.fspath(path))
+                cursor.execute("SELECT dummy FROM test_airflow")
+                results = tuple(result[0] for result in cursor.fetchall())
+                assert sorted(results) == sorted(records)
 
     @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"])
     def test_mysql_hook_test_bulk_dump(self, client):
diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py
index 44aed62917..04ead90f16 100644
--- a/tests/providers/mysql/operators/test_mysql.py
+++ b/tests/providers/mysql/operators/test_mysql.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 import os
 from contextlib import closing
-from tempfile import NamedTemporaryFile
 from unittest.mock import MagicMock
 
 import pytest
@@ -100,18 +99,14 @@ class TestMySql:
             except OperationalError as e:
                 assert "Unknown database 'foobar'" in str(e)
 
-    def test_mysql_operator_resolve_parameters_template_json_file(self):
+    def test_mysql_operator_resolve_parameters_template_json_file(self, tmp_path):
+        path = tmp_path / "testfile.json"
+        path.write_text('{\n "foo": "{{ ds }}"}')
 
-        with NamedTemporaryFile(suffix=".json") as f:
-            f.write(b'{\n "foo": "{{ ds }}"}')
-            f.flush()
-            template_dir = os.path.dirname(f.name)
-            template_file = os.path.basename(f.name)
+        with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)):
+            task = MySqlOperator(task_id="op1", parameters=path.name, sql="SELECT 1")
 
-            with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
-                task = MySqlOperator(task_id="op1", parameters=template_file, sql="SELECT 1")
-
-            task.resolve_template_files()
+        task.resolve_template_files()
 
         assert isinstance(task.parameters, dict)
         assert task.parameters["foo"] == "{{ ds }}"

Reply via email to