This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b04be0f8dc Refactor: tmp_path in tests/models (#33560)
b04be0f8dc is described below
commit b04be0f8dcfee31b82b5231ced0f4153afde3387
Author: Miroslav Šedivý <[email protected]>
AuthorDate: Mon Aug 21 08:30:49 2023 +0000
Refactor: tmp_path in tests/models (#33560)
---
tests/models/test_dag.py | 89 +++++++++-----------
tests/models/test_dagbag.py | 195 ++++++++++++++++++++++----------------------
2 files changed, 135 insertions(+), 149 deletions(-)
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 88b8f238df..87835c8adc 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -28,7 +28,6 @@ import weakref
from contextlib import redirect_stdout
from datetime import timedelta
from pathlib import Path
-from tempfile import NamedTemporaryFile
from unittest import mock
from unittest.mock import patch
@@ -642,37 +641,31 @@ class TestDag:
jinja_env = dag.get_template_env(force_sandboxed=force_sandboxed)
assert isinstance(jinja_env, expected_env)
- def test_resolve_template_files_value(self):
- with NamedTemporaryFile(suffix=".template") as f:
- f.write(b"{{ ds }}")
- f.flush()
- template_dir = os.path.dirname(f.name)
- template_file = os.path.basename(f.name)
+ def test_resolve_template_files_value(self, tmp_path):
+ path = tmp_path / "testfile.template"
+ path.write_text("{{ ds }}")
- with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
- task = EmptyOperator(task_id="op1")
+ with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)):
+ task = EmptyOperator(task_id="op1")
- task.test_field = template_file
- task.template_fields = ("test_field",)
- task.template_ext = (".template",)
- task.resolve_template_files()
+ task.test_field = path.name
+ task.template_fields = ("test_field",)
+ task.template_ext = (".template",)
+ task.resolve_template_files()
assert task.test_field == "{{ ds }}"
- def test_resolve_template_files_list(self):
- with NamedTemporaryFile(suffix=".template") as f:
- f.write(b"{{ ds }}")
- f.flush()
- template_dir = os.path.dirname(f.name)
- template_file = os.path.basename(f.name)
+ def test_resolve_template_files_list(self, tmp_path):
+ path = tmp_path / "testfile.template"
+ path.write_text("{{ ds }}")
- with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir):
- task = EmptyOperator(task_id="op1")
+ with DAG("test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent)):
+ task = EmptyOperator(task_id="op1")
- task.test_field = [template_file, "some_string"]
- task.template_fields = ("test_field",)
- task.template_ext = (".template",)
- task.resolve_template_files()
+ task.test_field = [path.name, "some_string"]
+ task.template_fields = ("test_field",)
+ task.template_ext = (".template",)
+ task.resolve_template_files()
assert task.test_field == ["{{ ds }}", "some_string"]
@@ -1988,7 +1981,7 @@ class TestDag:
dag.test()
mock_object.assert_called_with([0, 1, 2, 3, 4])
- def test_dag_connection_file(self):
+ def test_dag_connection_file(self, tmp_path):
test_connections_string = """
---
my_postgres_conn:
@@ -2008,10 +2001,9 @@ my_postgres_conn:
with dag:
check_task()
- with NamedTemporaryFile(suffix=".yaml") as tmp:
- with open(tmp.name, "w") as f:
- f.write(test_connections_string)
- dag.test(conn_file_path=tmp.name)
+ path = tmp_path / "testfile.yaml"
+ path.write_text(test_connections_string)
+ dag.test(conn_file_path=os.fspath(path))
def _make_test_subdag(self, session):
dag_id = "test_subdag"
@@ -2888,31 +2880,28 @@ class TestDagDecorator:
assert dag.dag_id == "noop_pipeline"
assert "Regular DAG documentation" in dag.doc_md
- def test_resolve_documentation_template_file_rendered(self):
+ def test_resolve_documentation_template_file_rendered(self, tmp_path):
"""Test that @dag uses function docs as doc_md for DAG object"""
- with NamedTemporaryFile(suffix=".md") as f:
- f.write(
- b"""
- {% if True %}
- External Markdown DAG documentation
- {% endif %}
+ path = tmp_path / "testfile.md"
+ path.write_text(
"""
- )
- f.flush()
- template_dir = os.path.dirname(f.name)
- template_file = os.path.basename(f.name)
+ {% if True %}
+ External Markdown DAG documentation
+ {% endif %}
+ """
+ )
- @dag_decorator(
- "test-dag", start_date=DEFAULT_DATE, template_searchpath=template_dir, doc_md=template_file
- )
- def markdown_docs():
- ...
+ @dag_decorator(
+ "test-dag", start_date=DEFAULT_DATE, template_searchpath=os.fspath(path.parent), doc_md=path.name
+ )
+ def markdown_docs():
+ ...
- dag = markdown_docs()
- assert isinstance(dag, DAG)
- assert dag.dag_id == "test-dag"
- assert dag.doc_md.strip() == "External Markdown DAG documentation"
+ dag = markdown_docs()
+ assert isinstance(dag, DAG)
+ assert dag.dag_id == "test-dag"
+ assert dag.doc_md.strip() == "External Markdown DAG documentation"
def test_fails_if_arg_not_set(self):
"""Test that @dag decorated function fails if positional argument is
not set"""
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 22ff525bb8..b9e5b5f2f1 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -20,13 +20,11 @@ import inspect
import logging
import os
import pathlib
-import shutil
import sys
import textwrap
import zipfile
from copy import deepcopy
from datetime import datetime, timedelta, timezone
-from tempfile import NamedTemporaryFile, mkdtemp
from typing import Iterator
from unittest import mock
from unittest.mock import patch
@@ -63,18 +61,16 @@ def db_clean_up():
class TestDagBag:
def setup_class(self):
- self.empty_dir = mkdtemp()
db_clean_up()
def teardown_class(self):
- shutil.rmtree(self.empty_dir)
db_clean_up()
- def test_get_existing_dag(self):
+ def test_get_existing_dag(self, tmp_path):
"""
Test that we're able to parse some example DAGs and retrieve them
"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=True)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=True)
some_expected_dag_ids = ["example_bash_operator",
"example_branch_operator"]
@@ -86,78 +82,77 @@ class TestDagBag:
assert dagbag.size() >= 7
- def test_get_non_existing_dag(self):
+ def test_get_non_existing_dag(self, tmp_path):
"""
test that retrieving a non existing dag id returns None without
crashing
"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
non_existing_dag_id = "non_existing_dag_id"
assert dagbag.get_dag(non_existing_dag_id) is None
- def test_serialized_dag_not_existing_doesnt_raise(self):
+ def test_serialized_dag_not_existing_doesnt_raise(self, tmp_path):
"""
test that retrieving a non existing dag id returns None without
crashing
"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
non_existing_dag_id = "non_existing_dag_id"
assert dagbag.get_dag(non_existing_dag_id) is None
- def test_dont_load_example(self):
+ def test_dont_load_example(self, tmp_path):
"""
test that the example are not loaded
"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert dagbag.size() == 0
- def test_safe_mode_heuristic_match(self):
+ def test_safe_mode_heuristic_match(self, tmp_path):
"""With safe mode enabled, a file matching the discovery heuristics
should be discovered.
"""
- with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
- f.write(b"# airflow")
- f.write(b"# DAG")
- f.flush()
+ path = tmp_path / "testfile.py"
+ path.write_text("# airflow\n# DAG")
- with conf_vars({("core", "dags_folder"): self.empty_dir}):
- dagbag = models.DagBag(include_examples=False, safe_mode=True)
+ with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+ dagbag = models.DagBag(include_examples=False, safe_mode=True)
- assert len(dagbag.dagbag_stats) == 1
- assert dagbag.dagbag_stats[0].file == f"/{os.path.basename(f.name)}"
+ assert len(dagbag.dagbag_stats) == 1
+ assert dagbag.dagbag_stats[0].file == f"/{path.name}"
- def test_safe_mode_heuristic_mismatch(self):
+ def test_safe_mode_heuristic_mismatch(self, tmp_path):
"""With safe mode enabled, a file not matching the discovery heuristics
should not be discovered.
"""
- with NamedTemporaryFile(dir=self.empty_dir, suffix=".py"):
- with conf_vars({("core", "dags_folder"): self.empty_dir}):
- dagbag = models.DagBag(include_examples=False, safe_mode=True)
- assert len(dagbag.dagbag_stats) == 0
+ path = tmp_path / "testfile.py"
+ path.write_text("")
+ with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+ dagbag = models.DagBag(include_examples=False, safe_mode=True)
+ assert len(dagbag.dagbag_stats) == 0
- def test_safe_mode_disabled(self):
+ def test_safe_mode_disabled(self, tmp_path):
"""With safe mode disabled, an empty python file should be
discovered."""
- with NamedTemporaryFile(dir=self.empty_dir, suffix=".py") as f:
- with conf_vars({("core", "dags_folder"): self.empty_dir}):
- dagbag = models.DagBag(include_examples=False, safe_mode=False)
- assert len(dagbag.dagbag_stats) == 1
- assert dagbag.dagbag_stats[0].file == f"/{os.path.basename(f.name)}"
-
- def test_process_file_that_contains_multi_bytes_char(self):
+ path = tmp_path / "testfile.py"
+ path.write_text("")
+ with conf_vars({("core", "dags_folder"): os.fspath(path.parent)}):
+ dagbag = models.DagBag(include_examples=False, safe_mode=False)
+ assert len(dagbag.dagbag_stats) == 1
+ assert dagbag.dagbag_stats[0].file == f"/{path.name}"
+
+ def test_process_file_that_contains_multi_bytes_char(self, tmp_path):
"""
test that we're able to parse file that contains multi-byte char
"""
- with NamedTemporaryFile() as f:
- f.write("\u3042".encode()) # write multi-byte char (hiragana)
- f.flush()
+ path = tmp_path / "testfile"
+ path.write_text("\u3042") # write multi-byte char (hiragana)
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- assert [] == dagbag.process_file(f.name)
+ dagbag = models.DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
+ assert [] == dagbag.process_file(os.fspath(path))
- def test_process_file_duplicated_dag_id(self):
+ def test_process_file_duplicated_dag_id(self, tmp_path):
"""Loading a DAG with ID that already existed in a DAG bag should
result in an import error."""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
def create_dag():
from airflow.decorators import dag
@@ -169,23 +164,22 @@ class TestDagBag:
my_dag = my_flow() # noqa
source_lines = [line[12:] for line in inspect.getsource(create_dag).splitlines(keepends=True)[1:]]
- with NamedTemporaryFile("w+", encoding="utf8") as tf_1,
NamedTemporaryFile(
- "w+", encoding="utf8"
- ) as tf_2:
- tf_1.writelines(source_lines)
- tf_2.writelines(source_lines)
- tf_1.flush()
- tf_2.flush()
-
- found_1 = dagbag.process_file(tf_1.name)
- assert len(found_1) == 1 and found_1[0].dag_id == "my_flow"
- assert dagbag.import_errors == {}
- dags_in_bag = dagbag.dags
-
- found_2 = dagbag.process_file(tf_2.name)
- assert len(found_2) == 0
- assert dagbag.import_errors[tf_2.name].startswith("AirflowDagDuplicatedIdException: Ignoring DAG")
- assert dagbag.dags == dags_in_bag # Should not change.
+ path1 = tmp_path / "testfile1"
+ path2 = tmp_path / "testfile2"
+ path1.write_text("".join(source_lines))
+ path2.write_text("".join(source_lines))
+
+ found_1 = dagbag.process_file(os.fspath(path1))
+ assert len(found_1) == 1 and found_1[0].dag_id == "my_flow"
+ assert dagbag.import_errors == {}
+ dags_in_bag = dagbag.dags
+
+ found_2 = dagbag.process_file(os.fspath(path2))
+ assert len(found_2) == 0
+ assert dagbag.import_errors[os.fspath(path2)].startswith(
+ "AirflowDagDuplicatedIdException: Ignoring DAG"
+ )
+ assert dagbag.dags == dags_in_bag # Should not change.
def test_zip_skip_log(self, caplog):
"""
@@ -202,37 +196,39 @@ class TestDagBag:
"assumed to contain no DAGs. Skipping." in caplog.text
)
- def test_zip(self):
+ def test_zip(self, tmp_path):
"""
test the loading of a DAG within a zip file that includes dependencies
"""
syspath_before = deepcopy(sys.path)
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
assert dagbag.get_dag("test_zip_dag")
assert sys.path == syspath_before # sys.path doesn't change
@patch("airflow.models.dagbag.timeout")
@patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
- def test_process_dag_file_without_timeout(self, mocked_get_dagbag_import_timeout, mocked_timeout):
+ def test_process_dag_file_without_timeout(
+ self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
+ ):
"""
Test dag file parsing without timeout
"""
mocked_get_dagbag_import_timeout.return_value = 0
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
mocked_timeout.assert_not_called()
mocked_get_dagbag_import_timeout.return_value = -1
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
mocked_timeout.assert_not_called()
@patch("airflow.models.dagbag.timeout")
@patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
def test_process_dag_file_with_non_default_timeout(
- self, mocked_get_dagbag_import_timeout, mocked_timeout
+ self, mocked_get_dagbag_import_timeout, mocked_timeout, tmp_path
):
"""
Test customized dag file parsing timeout
@@ -243,19 +239,21 @@ class TestDagBag:
# ensure the test value is not equal to the default value
assert timeout_value != settings.conf.getfloat("core",
"DAGBAG_IMPORT_TIMEOUT")
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_default_views.py"))
mocked_timeout.assert_called_once_with(timeout_value, error_message=mock.ANY)
@patch("airflow.models.dagbag.settings.get_dagbag_import_timeout")
- def test_check_value_type_from_get_dagbag_import_timeout(self, mocked_get_dagbag_import_timeout):
+ def test_check_value_type_from_get_dagbag_import_timeout(
+ self, mocked_get_dagbag_import_timeout, tmp_path
+ ):
"""
Test correctness of value from get_dagbag_import_timeout
"""
mocked_get_dagbag_import_timeout.return_value = "1"
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
with pytest.raises(
TypeError, match=r"Value \(1\) from get_dagbag_import_timeout must
be int or float"
):
@@ -267,27 +265,28 @@ class TestDagBag:
@pytest.fixture()
def invalid_cron_zipped_dag(self, invalid_cron_dag: str, tmp_path: pathlib.Path) -> Iterator[str]:
- zipped = os.path.join(tmp_path, "test_zip_invalid_cron.zip")
+ zipped = tmp_path / "test_zip_invalid_cron.zip"
with zipfile.ZipFile(zipped, "w") as zf:
zf.write(invalid_cron_dag, os.path.basename(invalid_cron_dag))
- yield zipped
- os.unlink(zipped)
+ yield os.fspath(zipped)
@pytest.mark.parametrize("invalid_dag_name", ["invalid_cron_dag",
"invalid_cron_zipped_dag"])
- def test_process_file_cron_validity_check(self, request:
pytest.FixtureRequest, invalid_dag_name: str):
+ def test_process_file_cron_validity_check(
+ self, request: pytest.FixtureRequest, invalid_dag_name: str, tmp_path
+ ):
"""test if an invalid cron expression as schedule interval can be
identified"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert len(dagbag.import_errors) == 0
dagbag.process_file(request.getfixturevalue(invalid_dag_name))
assert len(dagbag.import_errors) == 1
assert len(dagbag.dags) == 0
- def test_process_file_invalid_param_check(self):
+ def test_process_file_invalid_param_check(self, tmp_path):
"""
test if an invalid param in the dag param can be identified
"""
invalid_dag_files = ["test_invalid_param.py"]
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert len(dagbag.import_errors) == 0
for file in invalid_dag_files:
@@ -351,7 +350,7 @@ class TestDagBag:
)
def test_get_dag_registration(self, file_to_load, expected):
dagbag = models.DagBag(dag_folder=os.devnull, include_examples=False)
- dagbag.process_file(str(file_to_load))
+ dagbag.process_file(os.fspath(file_to_load))
for dag_id, path in expected.items():
dag = dagbag.get_dag(dag_id)
assert dag, f"{dag_id} was bagged"
@@ -366,12 +365,11 @@ class TestDagBag:
def zip_with_valid_dag_and_dup_tasks(self, tmp_path: pathlib.Path) -> Iterator[str]:
failing_dag_file = TEST_DAGS_FOLDER / "test_invalid_dup_task.py"
working_dag_file = TEST_DAGS_FOLDER / "test_example_bash_operator.py"
- zipped = os.path.join(tmp_path, "test_zip_invalid_dup_task.zip")
+ zipped = tmp_path / "test_zip_invalid_dup_task.zip"
with zipfile.ZipFile(zipped, "w") as zf:
- zf.write(failing_dag_file, os.path.basename(failing_dag_file))
- zf.write(working_dag_file, os.path.basename(working_dag_file))
- yield zipped
- os.unlink(zipped)
+ zf.write(failing_dag_file, failing_dag_file.name)
+ zf.write(working_dag_file, working_dag_file.name)
+ yield os.fspath(zipped)
def test_dag_registration_with_failure_zipped(self, zip_with_valid_dag_and_dup_tasks):
dagbag = models.DagBag(dag_folder=os.devnull, include_examples=False)
@@ -380,7 +378,7 @@ class TestDagBag:
assert ["test_example_bash_operator"] == [dag.dag_id for dag in found]
@patch.object(DagModel, "get_current")
- def test_refresh_py_dag(self, mock_dagmodel):
+ def test_refresh_py_dag(self, mock_dagmodel, tmp_path):
"""
Test that we can refresh an ordinary .py DAG
"""
@@ -400,7 +398,7 @@ class TestDagBag:
_TestDagBag.process_file_calls += 1
return super().process_file(filepath, only_if_updated,
safe_mode)
- dagbag = _TestDagBag(dag_folder=self.empty_dir, include_examples=True)
+ dagbag = _TestDagBag(dag_folder=os.fspath(tmp_path), include_examples=True)
assert 1 == dagbag.process_file_calls
dag = dagbag.get_dag(dag_id)
@@ -436,7 +434,7 @@ class TestDagBag:
assert dag_id == dag.dag_id
assert 2 == dagbag.process_file_calls
- def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker):
+ def test_dag_removed_if_serialized_dag_is_removed(self, dag_maker, tmp_path):
"""
Test that if a DAG does not exist in serialized_dag table (as the DAG
file was removed),
remove dags from the DagBag
@@ -450,7 +448,7 @@ class TestDagBag:
) as dag:
EmptyOperator(task_id="task_1")
dag_maker.create_dagrun()
- dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
+ dagbag = DagBag(dag_folder=os.fspath(tmp_path), include_examples=False, read_dags_from_db=True)
dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
dagbag.dags_hash = {dag.dag_id: mock.ANY}
@@ -462,19 +460,18 @@ class TestDagBag:
assert dag.dag_id not in dagbag.dags_last_fetched
assert dag.dag_id not in dagbag.dags_hash
- def process_dag(self, create_dag):
+ def process_dag(self, create_dag, tmp_path):
"""
Helper method to process a file generated from the input create_dag
function.
"""
# write source to file
source = textwrap.dedent("".join(inspect.getsource(create_dag).splitlines(True)[1:-1]))
- with NamedTemporaryFile() as f:
- f.write(source.encode("utf8"))
- f.flush()
+ path = tmp_path / "testfile"
+ path.write_text(source)
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
- found_dags = dagbag.process_file(f.name)
- return dagbag, found_dags, f.name
+ dagbag = models.DagBag(dag_folder=os.fspath(path.parent), include_examples=False)
+ found_dags = dagbag.process_file(os.fspath(path))
+ return dagbag, found_dags, os.fspath(path)
def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag, should_be_found=True):
expected_dag_ids = list(map(lambda dag: dag.dag_id, expected_parent_dag.subdags))
@@ -493,7 +490,7 @@ class TestDagBag:
f'be in dagbag.dags after processing dag "{expected_parent_dag.dag_id}"'
)
- def test_load_subdags(self):
+ def test_load_subdags(self, tmp_path):
# Define Dag to load
def standard_subdag():
import datetime
@@ -539,7 +536,7 @@ class TestDagBag:
assert len(test_dag.subdags) == 2
# Perform processing dag
- dagbag, found_dags, _ = self.process_dag(standard_subdag)
+ dagbag, found_dags, _ = self.process_dag(standard_subdag, tmp_path)
# Validate correctness
# all dags from test_dag should be listed
@@ -623,7 +620,7 @@ class TestDagBag:
assert len(test_dag.subdags) == 6
# Perform processing dag
- dagbag, found_dags, filename = self.process_dag(nested_subdags)
+ dagbag, found_dags, filename = self.process_dag(nested_subdags, tmp_path)
# Validate correctness
# all dags from test_dag should be listed
@@ -632,7 +629,7 @@ class TestDagBag:
for dag in dagbag.dags.values():
assert dag.fileloc == filename
- def test_skip_cycle_dags(self):
+ def test_skip_cycle_dags(self, tmp_path):
"""
Don't crash when loading an invalid (contains a cycle) DAG file.
Don't load the dag into the DagBag either
@@ -661,7 +658,7 @@ class TestDagBag:
assert len(test_dag.subdags) == 0
# Perform processing dag
- dagbag, found_dags, file_path = self.process_dag(basic_cycle)
+ dagbag, found_dags, file_path = self.process_dag(basic_cycle, tmp_path)
# #Validate correctness
# None of the dags should be found
@@ -748,18 +745,18 @@ class TestDagBag:
assert len(test_dag.subdags) == 6
# Perform processing dag
- dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle)
+ dagbag, found_dags, file_path = self.process_dag(nested_subdag_cycle, tmp_path)
# Validate correctness
# None of the dags should be found
self.validate_dags(test_dag, found_dags, dagbag, should_be_found=False)
assert file_path in dagbag.import_errors
- def test_process_file_with_none(self):
+ def test_process_file_with_none(self, tmp_path):
"""
test that process_file can handle Nones
"""
- dagbag = models.DagBag(dag_folder=self.empty_dir, include_examples=False)
+ dagbag = models.DagBag(dag_folder=os.fspath(tmp_path), include_examples=False)
assert [] == dagbag.process_file(None)