This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1585e4586aa Refactor airflow core tests to use SQLA2 (#59920)
1585e4586aa is described below

commit 1585e4586aa873bca89d95612fcc6155a54dad9a
Author: RickyChen / 陳昭儒 <[email protected]>
AuthorDate: Tue Dec 30 17:49:24 2025 +0800

    Refactor airflow core tests to use SQLA2 (#59920)
    
    * Refactor airflow-core tests to use SQLA2
    
    * Refactor airflow-core tests to use SQLA2
---
 .pre-commit-config.yaml                            |  1 +
 .../tests/unit/dag_processing/test_collection.py   | 36 +++++++++++-----------
 2 files changed, 19 insertions(+), 18 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 897ee78b18e..097d0a32308 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -466,6 +466,7 @@ repos:
             ^airflow-core/tests/unit/utils/test_types.py$|
             ^airflow-core/tests/unit/dag_processing/test_manager.py$|
             ^airflow-core/tests/unit/dag_processing/test_processor.py$|
+            ^airflow-core/tests/unit/dag_processing/test_collection\.py$|
             ^dev/airflow_perf/scheduler_dag_execution_timing.py$|
             ^providers/celery/.*\.py$|
             ^providers/cncf/kubernetes/.*\.py$|
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py
index 4ba37f859c0..ad4a8a73382 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -27,7 +27,7 @@ from unittest import mock
 from unittest.mock import patch
 
 import pytest
-from sqlalchemy import func, select
+from sqlalchemy import delete, func, select
 from sqlalchemy.exc import OperationalError, SAWarning
 
 import airflow.dag_processing.collection
@@ -336,7 +336,7 @@ class TestUpdateDagParsingResults:
         self, monkeypatch, spy_agency: SpyAgency, session, time_machine, testing_dag_bundle
     ):
         """Test DAG-specific permissions are synced when a DAG is new or 
updated"""
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 0
 
         time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False)
@@ -367,7 +367,7 @@ class TestUpdateDagParsingResults:
         _sync_to_db()
         spy_agency.assert_spy_called_with(sync_perms_spy, dag, session=session)
 
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
 
     @patch.object(SerializedDagModel, "write_dag")
     
@patch("airflow.serialization.definitions.dag.SerializedDAG.bulk_write_to_db")
@@ -375,7 +375,7 @@ class TestUpdateDagParsingResults:
         self, mock_bulk_write_to_db, mock_s10n_write_dag, testing_dag_bundle, session
     ):
         """Test that important DB operations in db sync are retried on 
OperationalError"""
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 0
         mock_dag = mock.MagicMock()
         dags = [mock_dag]
@@ -423,12 +423,12 @@ class TestUpdateDagParsingResults:
             ]
         )
 
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 0
 
     def test_serialized_dags_are_written_to_db_on_sync(self, testing_dag_bundle, session):
         """Test DAGs are Serialized and written to DB when parsing result is 
updated"""
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 0
 
         dag = DAG(dag_id="test")
@@ -443,7 +443,7 @@ class TestUpdateDagParsingResults:
             session=session,
         )
 
-        new_serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        new_serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert new_serialized_dags_count == 1
 
     def test_parse_time_written_to_db_on_sync(self, testing_dag_bundle, session):
@@ -485,7 +485,7 @@ class TestUpdateDagParsingResults:
         dag_model: DagModel = session.get(DagModel, (dag.dag_id,))
         assert dag_model.has_import_errors is True
 
-        import_errors = session.query(ParseImportError).all()
+        import_errors = session.scalars(select(ParseImportError)).all()
 
         assert len(import_errors) == 1
         import_error = import_errors[0]
@@ -514,7 +514,7 @@ class TestUpdateDagParsingResults:
         """
         Test that import errors related to invalid access control role are tracked in the DB until being fixed.
         """
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 0
         time_machine.move_to(tz.datetime(2020, 1, 5, 0, 0, 0), tick=False)
 
@@ -542,7 +542,7 @@ class TestUpdateDagParsingResults:
         # the DAG should contain an import error.
         assert dag_model.has_import_errors is True
 
-        prev_import_errors = session.query(ParseImportError).all()
+        prev_import_errors = session.scalars(select(ParseImportError)).all()
         # the import error message should match.
         assert len(prev_import_errors) == 1
         prev_import_error = prev_import_errors[0]
@@ -558,7 +558,7 @@ class TestUpdateDagParsingResults:
         )
 
         # the DAG is serialized into the DB.
-        serialized_dags_count = session.query(func.count(SerializedDagModel.dag_id)).scalar()
+        serialized_dags_count = session.scalar(select(func.count(SerializedDagModel.dag_id)))
         assert serialized_dags_count == 1
 
         # run the update again. Even though the DAG is not updated, the processor should raise import error since the access control is not fixed.
@@ -569,7 +569,7 @@ class TestUpdateDagParsingResults:
         # the DAG should contain an import error.
         assert dag_model.has_import_errors is True
 
-        import_errors = session.query(ParseImportError).all()
+        import_errors = session.scalars(select(ParseImportError)).all()
         # the import error should still in the DB.
         assert len(import_errors) == 1
         import_error = import_errors[0]
@@ -599,7 +599,7 @@ class TestUpdateDagParsingResults:
         # the import error should be cleared.
         assert dag_model.has_import_errors is False
 
-        import_errors = session.query(ParseImportError).all()
+        import_errors = session.scalars(select(ParseImportError)).all()
         # the import error should be cleared.
         assert len(import_errors) == 0
 
@@ -640,10 +640,10 @@ class TestUpdateDagParsingResults:
             files_parsed={("testing", "abc.py")},
         )
 
-        import_error = (
-            session.query(ParseImportError)
-            .filter(ParseImportError.filename == filename, ParseImportError.bundle_name == bundle_name)
-            .one()
+        import_error = session.scalar(
+            select(ParseImportError).where(
+                ParseImportError.filename == filename, ParseImportError.bundle_name == bundle_name
+            )
         )
 
         # assert that the ID of the import error did not change
@@ -981,7 +981,7 @@ class TestUpdateDagTags:
     @pytest.fixture(autouse=True)
     def setup_teardown(self, session):
         yield
-        session.query(DagModel).filter(DagModel.dag_id == "test_dag").delete()
+        session.execute(delete(DagModel).where(DagModel.dag_id == "test_dag"))
         session.commit()
 
     @pytest.mark.parametrize(

Reply via email to