This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b04be0f8dc Refactor: tmp_path in tests/models (#33560)
b04be0f8dc is described below

commit b04be0f8dcfee31b82b5231ced0f4153afde3387
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 21 08:30:49 2023 +0000

    Refactor: tmp_path in tests/models (#33560)
---
 tests/models/test_dag.py    |  89 +++++++++-----------
 tests/models/test_dagbag.py | 195 ++++++++++++++++++++++----------------------
 2 files changed, 135 insertions(+), 149 deletions(-)

diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 88b8f238df..87835c8adc 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -28,7 +28,6 @@ import weakref
 from contextlib import redirect_stdout
 from datetime import timedelta
 from pathlib import Path
-from tempfile import NamedTemporaryFile
 from unittest import mock
 from unittest.mock import patch
 
@@ -642,37 +641,31 @@ class TestDag:
         jinja_env = dag.get_template_env(force_sandboxed=force_sandboxed)
         assert isinstance(jinja_env, expected_env)
 
-    def test_resolve_template_files_value(self):
-        with NamedTemporaryFile(suffix=".template") as f:
-            f.write(b"{{ ds }}")
-            f.flush()
-            template_dir = os.path.dirname(f.name)
-            template_file = os.path.basename(f.name)
+    def test_resolve_template_files_value(self, tmp_path):
+        path = tmp_path / "testfile.template"
+        path.write_text("{{ ds }}")
 
-            with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
-                task = EmptyOperator(task_id="op1")
+        with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)):
+            task = EmptyOperator(task_id="op1")
 
-            task.test_field = template_file
-            task.template_fields = ("test_field",)
-            task.template_ext = (".template",)
-            task.resolve_template_files()
+        task.test_field = path.name
+        task.template_fields = ("test_field",)
+        task.template_ext = (".template",)
+        task.resolve_template_files()
 
         assert task.test_field == "{{ ds }}"
 
-    def test_resolve_template_files_list(self):
-        with NamedTemporaryFile(suffix=".template") as f:
-            f.write(b"{{ ds }}")
-            f.flush()
-            template_dir = os.path.dirname(f.name)
-            template_file = os.path.basename(f.name)
+    def test_resolve_template_files_list(self, tmp_path):
+        path = tmp_path / "testfile.template"
+        path.write_text("{{ ds }}")
 
-            with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
-                task = EmptyOperator(task_id="op1")
+        with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)):
+            task = EmptyOperator(task_id="op1")
 
-            task.test_field = [template_file, "some_string"]
-            task.template_fields = ("test_field",)
-            task.template_ext = (".template",)
-            task.resolve_template_files()
+        task.test_field = [path.name, "some_string"]
+        task.template_fields = ("test_field",)
+        task.template_ext = (".template",)
+        task.resolve_template_files()
 
         assert task.test_field == ["{{ ds }}", "some_string"]
 
@@ -1988,7 +1981,7 @@ class TestDag:
         dag.test()
         mock_object.assert_called_with([0, 1, 2, 3, 4])
 
-    def test_dag_connection_file(self):
+    def test_dag_connection_file(self, tmp_path):
         test_connections_string = """
 ---
 my_postgres_conn:
@@ -2008,10 +2001,9 @@ my_postgres_conn:
 
         with dag:
             check_task()
-        with NamedTemporaryFile(suffix=".yaml") as tmp:
-            with open(tmp.name, "w") as f:
-                f.write(test_connections_string)
-            dag.test(conn_file_path=tmp.name)
+        path = tmp_path / "testfile.yaml"
+        path.write_text(test_connections_string)
+        dag.test(conn_file_path=os.fspath(path))
 
     def _make_test_subdag(self, session):
         dag_id = "test_subdag"
@@ -2888,31 +2880,28 @@ class TestDagDecorator:
         assert dag.dag_id == "noop_pipeline"
         assert "Regular DAG documentation" in dag.doc_md
 
-    def test_resolve_documentation_template_file_rendered(self):
+    def test_resolve_documentation_template_file_rendered(self, tmp_path):
         """Test that @dag uses function docs as doc_md for DAG object"""
 
-        with NamedTemporaryFile(suffix=".md") as f:
-            f.write(
-                b"""
-            {% if True %}
-               External Markdown DAG documentation
-            {% endif %}
+        path = tmp_path / "testfile.md"
+        path.write_text(
             """
-            )
-            f.flush()
-            template_dir = os.path.dirname(f.name)
-            template_file = os.path.basename(f.name)
+        {% if True %}
+            External Markdown DAG documentation
+        {% endif %}
+        """
+        )
 
-            @dag_decorator(
-                "test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir, doc_md=template_file
-            )
-            def markdown_docs():
-                ...
+        @dag_decorator(
+            "test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent), doc_md=path.name
+        )
+        def markdown_docs():
+            ...
 
-            dag = markdown_docs()
-            assert isinstance(dag, DAG)
-            assert dag.dag_id == "test-dag"
-            assert dag.doc_md.strip() == "External Markdown DAG documentation"
+        dag = markdown_docs()
+        assert isinstance(dag, DAG)
+        assert dag.dag_id == "test-dag"
+        assert dag.doc_md.strip() == "External Markdown DAG documentation"
 
     def test_fails_if_arg_not_set(self):
         """Test that @dag decorated function fails if positional argument is not set"""
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 22ff525bb8..b9e5b5f2f1 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -20,13 +20,11 @@ import inspect
 import logging
 import os
 import pathlib
-import shutil
 import sys
 import textwrap
 import zipfile
 from copy import deepcopy
 from datetime import datetime, timedelta, timezone
-from tempfile import NamedTemporaryFile, mkdtemp
 from typing import Iterator
 from unittest import mock
 from unittest.mock import patch
@@ -63,18 +61,16 @@ def db_clean_up():
 
 class TestDagBag:
     def setup_class(self):
-        self.empty_dir = mkdtemp()
         db_clean_up()
 
     def teardown_class(self):
-        shutil.rmtree(self.empty_dir)
         db_clean_up()
 
-    def test_get_existing_dag(self):
+    def test_get_existing_dag(self, tmp_path):
         """
         Test that we're able to parse some example DAGs and retrieve them
         """
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=True)
 
         some_expected_dag_ids = ["example_bash_operator", "example_branch_operator"]
 
@@ -86,78 +82,77 @@ class TestDagBag:
 
         assert dagbag.size() >= 7
 
-    def test_get_non_existing_dag(self):
+    def test_get_non_existing_dag(self, tmp_path):
         """
         test that retrieving a non existing dag id returns None without crashing
         """
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
 
         non_existing_dag_id = "non_existing_dag_id"
         assert dagbag.get_dag(non_existing_dag_id) is None
 
-    def test_serialized_dag_not_existing_doesnt_raise(self):
+    def test_serialized_dag_not_existing_doesnt_raise(self, tmp_path):
         """
         test that retrieving a non existing dag id returns None without crashing
         """
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
 
         non_existing_dag_id = "non_existing_dag_id"
         assert dagbag.get_dag(non_existing_dag_id) is None
 
-    def test_dont_load_example(self):
+    def test_dont_load_example(self, tmp_path):
         """
         test that the example are not loaded
         """
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
 
         assert dagbag.size() == 0
 
-    def test_safe_mode_heuristic_match(self):
+    def test_safe_mode_heuristic_match(self, tmp_path):
         """With safe mode enabled, a file matching the discovery heuristics
         should be discovered.
         """
-        with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
-            f.write(b"# airflow")
-            f.write(b"# DAG")
-            f.flush()
+        path = tmp_path / "testfile.py"
+        path.write_text("# airflow\n# DAG")
 
-            with conf_vars({("core", "dags_folder"): self.empty_dir}):
-                dagbag = models.DagBag(include_examples=False, safe_mode=True)
+        with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+            dagbag = models.DagBag(include_examples=False, safe_mode=True)
 
-            assert len(dagbag.dagbag_stats) == 1
-            assert dagbag.dagbag_stats[0].file == f"/{os.path.basename(f.name)}"
+        assert len(dagbag.dagbag_stats) == 1
+        assert dagbag.dagbag_stats[0].file == f"/{path.name}"
 
-    def test_safe_mode_heuristic_mismatch(self):
+    def test_safe_mode_heuristic_mismatch(self, tmp_path):
         """With safe mode enabled, a file not matching the discovery heuristics
         should not be discovered.
         """
-        with NamedTemporaryFile(dir=self.empty_dir, suffix=".py"):
-            with conf_vars({("core", "dags_folder"): self.empty_dir}):
-                dagbag = models.DagBag(include_examples=False, safe_mode=True)
-            assert len(dagbag.dagbag_stats) == 0
+        path = tmp_path / "testfile.py"
+        path.write_text("")
+        with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+            dagbag = models.DagBag(include_examples=False, safe_mode=True)
+        assert len(dagbag.dagbag_stats) == 0
 
-    def test_safe_mode_disabled(self):
+    def test_safe_mode_disabled(self, tmp_path):
         """With safe mode disabled, an empty python file should be discovered."""
-        with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
-            with conf_vars({("core", "dags_folder"): self.empty_dir}):
-                dagbag = models.DagBag(include_examples=False, safe_mode=False)
-            assert len(dagbag.dagbag_stats) == 1
-            assert dagbag.dagbag_stats[0].file == f"/{os.path.basename(f.name)}"
-
-    def test_process_file_that_contains_multi_bytes_char(self):
+        path = tmp_path / "testfile.py"
+        path.write_text("")
+        with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+            dagbag = models.DagBag(include_examples=False, safe_mode=False)
+        assert len(dagbag.dagbag_stats) == 1
+        assert dagbag.dagbag_stats[0].file == f"/{path.name}"
+
+    def test_process_file_that_contains_multi_bytes_char(self, tmp_path):
         """
         test that we're able to parse file that contains multi-byte char
         """
-        with NamedTemporaryFile() as f:
-            f.write("\u3042".encode())  # write multi-byte char (hiragana)
-            f.flush()
+        path = tmp_path / "testfile"
+        path.write_text("\u3042")  # write multi-byte char (hiragana)
 
-            dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
-            assert [] == dagbag.process_file(f.name)
+        dagbag = models.DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
+        assert [] == dagbag.process_file(os.fspath(path))
 
-    def test_process_file_duplicated_dag_id(self):
+    def test_process_file_duplicated_dag_id(self, tmp_path):
         """Loading a DAG with ID that already existed in a DAG bag should result in an import error."""
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
 
         def create_dag():
             from airflow.decorators import dag
@@ -169,23 +164,22 @@ class TestDagBag:
             my_dag = my_flow()  # noqa
 
         source_lines = [line[12:] for line in inspect.getsource(create_dag).splitlines(keepends=True)[1:]]
-        with NamedTemporaryFile("w+", encoding="utf8") as tf_1, NamedTemporaryFile(
-            "w+", encoding="utf8"
-        ) as tf_2:
-            tf_1.writelines(source_lines)
-            tf_2.writelines(source_lines)
-            tf_1.flush()
-            tf_2.flush()
-
-            found_1 = dagbag.process_file(tf_1.name)
-            assert len(found_1) == 1 and found_1[0].dag_id == "my_flow"
-            assert dagbag.import_errors == {}
-            dags_in_bag = dagbag.dags
-
-            found_2 = dagbag.process_file(tf_2.name)
-            assert len(found_2) == 0
-            assert dagbag.import_errors[tf_2.name].startswith("AirflowDagDuplicatedIdException: Ignoring DAG")
-            assert dagbag.dags == dags_in_bag  # Should not change.
+        path1 = tmp_path / "testfile1"
+        path2 = tmp_path / "testfile2"
+        path1.write_text("".join(source_lines))
+        path2.write_text("".join(source_lines))
+
+        found_1 = dagbag.process_file(os.fspath(path1))
+        assert len(found_1) == 1 and found_1[0].dag_id == "my_flow"
+        assert dagbag.import_errors == {}
+        dags_in_bag = dagbag.dags
+
+        found_2 = dagbag.process_file(os.fspath(path2))
+        assert len(found_2) == 0
+        assert dagbag.import_errors[os.fspath(path2)].startswith(
+            "AirflowDagDuplicatedIdException: Ignoring DAG"
+        )
+        assert dagbag.dags == dags_in_bag  # Should not change.
 
     def test_zip_skip_log(self, caplog):
         """
@@ -202,37 +196,39 @@ class TestDagBag:
             "assumed to contain no DAGs. Skipping." in caplog.text
         )
 
-    def test_zip(self):
+    def test_zip(self, tmp_path):
         """
         test the loading of a DAG within a zip file that includes dependencies
         """
         syspath_before = deepcopy(sys.path)
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
         assert dagbag.get_dag("test_zip_dag")
         assert sys.path == syspath_before  # sys.path doesn't change
 
     @patch("airflow.models.dagbag.timeout")
     @patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
-    def test_process_dag_file_without_timeout(self, mocked_get_dagbag_import_timeout, mocked_timeout):
+    def test_process_dag_file_without_timeout(
+        self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
+    ):
         """
         Test dag file parsing without timeout
         """
         mocked_get_dagbag_import_timeout.return_value = 0
 
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
         mocked_timeout.assert_not_called()
 
         mocked_get_dagbag_import_timeout.return_value = -1
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
         mocked_timeout.assert_not_called()
 
     @patch("airflow.models.dagbag.timeout")
     @patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
     def test_process_dag_file_with_non_default_timeout(
-        self, mocked_get_dagbag_import_timeout, mocked_timeout
+        self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
     ):
         """
         Test customized dag file parsing timeout
@@ -243,19 +239,21 @@ class TestDagBag:
         # ensure the test value is not equal to the default value
         assert timeout_value != settings.conf.getfloat("core", "DAGBAG_IMPORT_TIMEOUT")
 
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
 
         mocked_timeout.assert_called_once_with(timeout_value, error_message=mock.ANY)
 
     @patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
-    def test_check_value_type_from_get_dagbag_import_timeout(self, mocked_get_dagbag_import_timeout):
+    def test_check_value_type_from_get_dagbag_import_timeout(
+        self, mocked_get_dagbag_import_timeout, tmp_path
+    ):
         """
         Test correctness of value from get_dagbag_import_timeout
         """
         mocked_get_dagbag_import_timeout.return_value = "1"
 
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         with pytest.raises(
             TypeError, match=r"Value \(1\) from get_dagbag_import_timeout must be int or float"
         ):
@@ -267,27 +265,28 @@ class TestDagBag:
 
     @pytest.fixture()
     def invalid_cron_zipped_dag(self, invalid_cron_dag: str, tmp_path: pathlib.Path) -> Iterator[str]:
-        zipped = os.path.join(tmp_path, "test_zip_invalid_cron.zip")
+        zipped = tmp_path / "test_zip_invalid_cron.zip"
         with zipfile.ZipFile(zipped, "w") as zf:
             zf.write(invalid_cron_dag, os.path.basename(invalid_cron_dag))
-        yield zipped
-        os.unlink(zipped)
+        yield os.fspath(zipped)
 
     @pytest.mark.parametrize("invalid_dag_name", ["invalid_cron_dag", "invalid_cron_zipped_dag"])
-    def test_process_file_cron_validity_check(self, request: pytest.FixtureRequest, invalid_dag_name: str):
+    def test_process_file_cron_validity_check(
+        self, request: pytest.FixtureRequest, invalid_dag_name: str, tmp_path
+    ):
         """test if an invalid cron expression as schedule interval can be identified"""
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
         assert len(dagbag.import_errors) == 0
         dagbag.process_file(request.getfixturevalue(invalid_dag_name))
         assert len(dagbag.import_errors) == 1
         assert len(dagbag.dags) == 0
 
-    def test_process_file_invalid_param_check(self):
+    def test_process_file_invalid_param_check(self, tmp_path):
         """
         test if an invalid param in the dag param can be identified
         """
         invalid_dag_files = ["test_invalid_param.py"]
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
 
         assert len(dagbag.import_errors) == 0
         for file in invalid_dag_files:
@@ -351,7 +350,7 @@ class TestDagBag:
     )
     def test_get_dag_registration(self, file_to_load, expected):
         dagbag = models.DagBag(dag_folder=os.devnull, include_examples=False)
-        dagbag.process_file(str(file_to_load))
+        dagbag.process_file(os.fspath(file_to_load))
         for dag_id, path in expected.items():
             dag = dagbag.get_dag(dag_id)
             assert dag, f"{dag_id} was bagged"
@@ -366,12 +365,11 @@ class TestDagBag:
     def zip_with_valid_dag_and_dup_tasks(self, tmp_path: pathlib.Path) -> Iterator[str]:
         failing_dag_file = TEST_DAGS_FOLDER / "test_invalid_dup_task.py"
         working_dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
-        zipped = os.path.join(tmp_path, "test_zip_invalid_dup_task.zip")
+        zipped = tmp_path / "test_zip_invalid_dup_task.zip"
         with zipfile.ZipFile(zipped, "w") as zf:
-            zf.write(failing_dag_file, os.path.basename(failing_dag_file))
-            zf.write(working_dag_file, os.path.basename(working_dag_file))
-        yield zipped
-        os.unlink(zipped)
+            zf.write(failing_dag_file, failing_dag_file.name)
+            zf.write(working_dag_file, working_dag_file.name)
+        yield os.fspath(zipped)
 
     def test_dag_registration_with_failure_zipped(self, zip_with_valid_dag_and_dup_tasks):
         dagbag = models.DagBag(dag_folder=os.devnull, include_examples=False)
@@ -380,7 +378,7 @@ class TestDagBag:
         assert ["test_example_bash_operator"] == [dag.dag_id for dag in found]
 
     @patch.object(DagModel, "get_current")
-    def test_refresh_py_dag(self, mock_dagmodel):
+    def test_refresh_py_dag(self, mock_dagmodel, tmp_path):
         """
         Test that we can refresh an ordinary .py DAG
         """
@@ -400,7 +398,7 @@ class TestDagBag:
                     _TestDagBag.process_file_calls += 1
                 return super().process_file(filepath, only_if_updated, safe_mode)
 
-        dagbag = _TestDagBag(dag_folder=self.empty_dir, include_examples=True)
+        dagbag = _TestDagBag(dag_folder=os.fspath(tmp_path), include_examples=True)
 
         assert 1 == dagbag.process_file_calls
         dag = dagbag.get_dag(dag_id)
@@ -436,7 +434,7 @@ class TestDagBag:
         assert dag_id == dag.dag_id
         assert 2 == dagbag.process_file_calls
 
-    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker):
+    def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
         """
         Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
         remove dags from the DagBag
@@ -450,7 +448,7 @@ class TestDagBag:
         ) as dag:
             EmptyOperator(task_id="task_1")
         dag_maker.create_dagrun()
-        dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
+        dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
         dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
         dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
         dagbag.dags_hash = {dag.dag_id: mock.ANY}
@@ -462,19 +460,18 @@ class TestDagBag:
         assert dag.dag_id not in dagbag.dags_last_fetched
         assert dag.dag_id not in dagbag.dags_hash
 
-    def process_dag(self, create_dag):
+    def process_dag(self, create_dag, tmp_path):
         """
         Helper method to process a file generated from the input create_dag function.
         """
         # write source to file
         source = textwrap.dedent("".join(inspect.getsource(create_dag).splitlines(True)[1:-1]))
-        with NamedTemporaryFile() as f:
-            f.write(source.encode("utf8"))
-            f.flush()
+        path = tmp_path / "testfile"
+        path.write_text(source)
 
-            dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
-            found_dags = dagbag.process_file(f.name)
-            return dagbag, found_dags, f.name
+        dagbag = models.DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
+        found_dags = dagbag.process_file(os.fspath(path))
+        return dagbag, found_dags, os.fspath(path)
 
     def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True):
         expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags))
@@ -493,7 +490,7 @@ class TestDagBag:
                 f'be in dagbag.dags after processing dag "{expected_parent_dag.dag_id}"'
             )
 
-    def test_load_subdags(self):
+    def test_load_subdags(self, tmp_path):
         # Define Dag to load
         def standard_subdag():
             import datetime
@@ -539,7 +536,7 @@ class TestDagBag:
         assert len(test_dag.subdags) == 2
 
         # Perform processing dag
-        dagbag, found_dags, _ = self.process_dag(standard_subdag)
+        dagbag, found_dags, _ = self.process_dag(standard_subdag, tmp_path)
 
         # Validate correctness
         # all dags from test_dag should be listed
@@ -623,7 +620,7 @@ class TestDagBag:
         assert len(test_dag.subdags) == 6
 
         # Perform processing dag
-        dagbag, found_dags, filename = self.process_dag(nested_subdags)
+        dagbag, found_dags, filename = self.process_dag(nested_subdags, tmp_path)
 
         # Validate correctness
         # all dags from test_dag should be listed
@@ -632,7 +629,7 @@ class TestDagBag:
         for dag in dagbag.dags.values():
             assert dag.fileloc == filename
 
-    def test_skip_cycle_dags(self):
+    def test_skip_cycle_dags(self, tmp_path):
         """
         Don't crash when loading an invalid (contains a cycle) DAG file.
         Don't load the dag into the DagBag either
@@ -661,7 +658,7 @@ class TestDagBag:
         assert len(test_dag.subdags) == 0
 
         # Perform processing dag
-        dagbag, found_dags, file_path = self.process_dag(basic_cycle)
+        dagbag, found_dags, file_path = self.process_dag(basic_cycle, tmp_path)
 
         # #Validate correctness
         # None of the dags should be found
@@ -748,18 +745,18 @@ class TestDagBag:
         assert len(test_dag.subdags) == 6
 
         # Perform processing dag
-        dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)
+        dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle, tmp_path)
 
         # Validate correctness
         # None of the dags should be found
         self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
         assert file_path in dagbag.import_errors
 
-    def test_process_file_with_none(self):
+    def test_process_file_with_none(self, tmp_path):
         """
         test that process_file can handle Nones
         """
-        dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+        dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
 
         assert [] == dagbag.process_file(None)
 

Reply via email to