This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1585e4586aa Refactor airflow core tests to use SQLA2 (#59920)
1585e4586aa is described below
commit 1585e4586aa873bca89d95612fcc6155a54dad9a
Author: RickyChen / 陳昭儒 <[email protected]>
AuthorDate: Tue Dec 30 17:49:24 2025 +0800
Refactor airflow core tests to use SQLA2 (#59920)
* Refactor airflow-core tests to use SQLA2
---
.pre-commit-config.yaml | 1 +
.../tests/unit/dag_processing/test_collection.py | 36 +++++++++++-----------
2 files changed, 19 insertions(+), 18 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 897ee78b18e..097d0a32308 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -466,6 +466,7 @@ repos:
^airflow-core/tests/unit/utils/test_types.py$|
^airflow-core/tests/unit/dag_processing/test_manager.py$|
^airflow-core/tests/unit/dag_processing/test_processor.py$|
+ ^airflow-core/tests/unit/dag_processing/test_collection\.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/celery/.*\.py$|
^providers/cncf/kubernetes/.*\.py$|
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py
index 4ba37f859c0..ad4a8a73382 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -27,7 +27,7 @@ from unittest import mock
from unittest.mock import patch
import pytest
-from sqlalchemy import func, select
+from sqlalchemy import delete, func, select
from sqlalchemy.exc import OperationalError, SAWarning
import airflow.dag_processing.collection
@@ -336,7 +336,7 @@ class TestUpdateDagParsingResults:
self, monkeypatch, spy_agency: SpyAgency, session, time_machine,
testing_dag_bundle
):
"""Test DAG-specific permissions are synced when a DAG is new or
updated"""
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 0
time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False)
@@ -367,7 +367,7 @@ class TestUpdateDagParsingResults:
_sync_to_db()
spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session)
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
@patch.object(SerializedDagModel, "write_dag")
@patch("airflow.serialization.definitions.dag.SerializedDAG.bulk_write_to_db")
@@ -375,7 +375,7 @@ class TestUpdateDagParsingResults:
self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle,
session
):
"""Test that important DB operations in db sync are retried on
OperationalError"""
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 0
mock_dag = mock.MagicMock()
dags = [mock_dag]
@@ -423,12 +423,12 @@ class TestUpdateDagParsingResults:
]
)
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 0
def test_serialized_dags_are_written_to_db_on_sync(self,
testing_dag_bundle, session):
"""Test DAGs are Serialized and written to DB when parsing result is
updated"""
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 0
dag = DAG(dag_id="test")
@@ -443,7 +443,7 @@ class TestUpdateDagParsingResults:
session=session,
)
- new_serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ new_serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert new_serialized_dags_count == 1
def test_parse_time_written_to_db_on_sync(self, testing_dag_bundle,
session):
@@ -485,7 +485,7 @@ class TestUpdateDagParsingResults:
dag_model: DagModel = session.get(DagModel, (dag.dag_id,))
assert dag_model.has_import_errors is True
- import_errors = session.query(ParseImportError).all()
+ import_errors = session.scalars(select(ParseImportError)).all()
assert len(import_errors) == 1
import_error = import_errors[0]
@@ -514,7 +514,7 @@ class TestUpdateDagParsingResults:
"""
Test that import errors related to invalid access control role are
tracked in the DB until being fixed.
"""
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 0
time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False)
@@ -542,7 +542,7 @@ class TestUpdateDagParsingResults:
# the DAG should contain an import error.
assert dag_model.has_import_errors is True
- prev_import_errors = session.query(ParseImportError).all()
+ prev_import_errors = session.scalars(select(ParseImportError)).all()
# the import error message should match.
assert len(prev_import_errors) == 1
prev_import_error = prev_import_errors[0]
@@ -558,7 +558,7 @@ class TestUpdateDagParsingResults:
)
# the DAG is serialized into the DB.
- serialized_dags_count =
session.query(func.count(SerializedDagModel.dag_id)).scalar()
+ serialized_dags_count =
session.scalar(select(func.count(SerializedDagModel.dag_id)))
assert serialized_dags_count == 1
# run the update again. Even though the DAG is not updated, the
processor should raise import error since the access control is not fixed.
@@ -569,7 +569,7 @@ class TestUpdateDagParsingResults:
# the DAG should contain an import error.
assert dag_model.has_import_errors is True
- import_errors = session.query(ParseImportError).all()
+ import_errors = session.scalars(select(ParseImportError)).all()
# the import error should still in the DB.
assert len(import_errors) == 1
import_error = import_errors[0]
@@ -599,7 +599,7 @@ class TestUpdateDagParsingResults:
# the import error should be cleared.
assert dag_model.has_import_errors is False
- import_errors = session.query(ParseImportError).all()
+ import_errors = session.scalars(select(ParseImportError)).all()
# the import error should be cleared.
assert len(import_errors) == 0
@@ -640,10 +640,10 @@ class TestUpdateDagParsingResults:
files_parsed={("testing", "abc.py")},
)
- import_error = (
- session.query(ParseImportError)
- .filter(ParseImportError.filename == filename,
ParseImportError.bundle_name == bundle_name)
- .one()
+ import_error = session.scalar(
+ select(ParseImportError).where(
+ ParseImportError.filename == filename,
ParseImportError.bundle_name == bundle_name
+ )
)
# assert that the ID of the import error did not change
@@ -981,7 +981,7 @@ class TestUpdateDagTags:
@pytest.fixture(autouse=True)
def setup_teardown(self, session):
yield
- session.query(DagModel).filter(DagModel.dag_id == "test_dag").delete()
+ session.execute(delete(DagModel).where(DagModel.dag_id == "test_dag"))
session.commit()
@pytest.mark.parametrize(