This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5b4908d7921 Add test to prevent ORM references in migration scripts (#60176)
5b4908d7921 is described below

commit 5b4908d79214e2a9a0e35a6d477ef10197f87e26
Author: Zach <[email protected]>
AuthorDate: Tue Jan 13 14:08:58 2026 -0500

    Add test to prevent ORM references in migration scripts (#60176)
    
    * Add pre-commit check for ORM references in migration scripts
    
    * Inspect inheritance of models references in hook
    
    * Run ORM reference check as a unit test rather than prek target
    
    * Update unit test modules list in failing test
---
 .../tests/unit/always/test_project_structure.py    | 16 +----
 airflow-core/tests/unit/migrations/__init__.py     | 16 +++++
 .../test_no_orm_refs_in_migration_scripts.py       | 68 ++++++++++++++++++++++
 .../tests/test_pytest_args_for_test_types.py       |  1 +
 .../src/tests_common/test_utils/file_loading.py    | 16 +++++
 scripts/ci/prek/check_tests_in_right_folders.py    |  1 +
 6 files changed, 103 insertions(+), 15 deletions(-)

diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py
index b3e0236fd32..3126224a420 100644
--- a/airflow-core/tests/unit/always/test_project_structure.py
+++ b/airflow-core/tests/unit/always/test_project_structure.py
@@ -24,6 +24,7 @@ import pathlib
 
 import pytest
 
+from tests_common.test_utils.file_loading import get_imports_from_file
 from tests_common.test_utils.paths import (
     AIRFLOW_CORE_SOURCES_PATH,
     AIRFLOW_PROVIDERS_ROOT_PATH,
@@ -248,21 +249,6 @@ class TestProjectStructure:
         )
 
 
-def get_imports_from_file(filepath: str):
-    with open(filepath) as py_file:
-        content = py_file.read()
-    doc_node = ast.parse(content, filepath)
-    import_names: set[str] = set()
-    for current_node in ast.walk(doc_node):
-        if not isinstance(current_node, (ast.Import, ast.ImportFrom)):
-            continue
-        for alias in current_node.names:
-            name = alias.name
-            fullname = f"{current_node.module}.{name}" if isinstance(current_node, ast.ImportFrom) else name
-            import_names.add(fullname)
-    return import_names
-
-
 def filepath_to_module(path: pathlib.Path, src_folder: pathlib.Path):
     path = path.relative_to(src_folder)
     return path.as_posix().replace("/", ".")[: -(len(".py"))]
diff --git a/airflow-core/tests/unit/migrations/__init__.py b/airflow-core/tests/unit/migrations/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/airflow-core/tests/unit/migrations/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow-core/tests/unit/migrations/test_no_orm_refs_in_migration_scripts.py b/airflow-core/tests/unit/migrations/test_no_orm_refs_in_migration_scripts.py
new file mode 100644
index 00000000000..93a1ed2c7c8
--- /dev/null
+++ b/airflow-core/tests/unit/migrations/test_no_orm_refs_in_migration_scripts.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Check that there are no imports of ORM classes in any of the alembic migration scripts.
+This is to prevent the addition of migration code directly referencing any ORM definition,
+which could potentially break downgrades. For more details, refer to the relevant discussion
+thread at this link: https://github.com/apache/airflow/issues/59871
+"""
+
+from __future__ import annotations
+
+import importlib
+import inspect
+import os
+from pathlib import Path
+from pprint import pformat
+from typing import Final
+
+import pytest
+
+from airflow.models.base import Base
+
+from tests_common.test_utils.file_loading import get_imports_from_file
+from tests_common.test_utils.paths import AIRFLOW_CORE_SOURCES_PATH
+
+_MIGRATIONS_DIRPATH: Final[Path] = Path(
+    os.path.join(AIRFLOW_CORE_SOURCES_PATH, "airflow/migrations/versions")
+)
+
+
+@pytest.mark.parametrize(
+    "migration_script_path",
+    [pytest.param(path, id=os.path.basename(path)) for path in list(_MIGRATIONS_DIRPATH.glob("**/*.py"))],
+)
+def test_migration_script_has_no_orm_references(migration_script_path: Path) -> None:
+    """Ensures the given alembic migration script path does not contain any ORM imports."""
+    bad_imports = []
+    for import_ref in get_imports_from_file(filepath=str(migration_script_path)):
+        if _is_violating_orm_import(import_ref=import_ref):
+            bad_imports.append(import_ref)
+    assert not bad_imports, f"{str(migration_script_path)} has bad ORM imports: {pformat(bad_imports)}"
+
+
+def _is_violating_orm_import(import_ref: str) -> bool:
+    """Return `True` if the imported object is an ORM class from within `airflow.models`, otherwise return `False`."""
+    if not import_ref.startswith("airflow.models"):
+        return False
+    # import the fully qualified reference to check if the reference is a subclass of a declarative base
+    mod_to_import, _, attr_name = import_ref.rpartition(".")
+    referenced_module = importlib.import_module(mod_to_import)
+    referenced_obj = getattr(referenced_module, attr_name)
+    if inspect.isclass(referenced_obj) and referenced_obj in Base.__subclasses__():
+        return True
+    return False
diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py
index b2b8193c5eb..982dc4e37b1 100644
--- a/dev/breeze/tests/test_pytest_args_for_test_types.py
+++ b/dev/breeze/tests/test_pytest_args_for_test_types.py
@@ -167,6 +167,7 @@ def _find_all_integration_folders() -> list[str]:
                 "airflow-core/tests/unit/listeners",
                 "airflow-core/tests/unit/logging",
                 "airflow-core/tests/unit/macros",
+                "airflow-core/tests/unit/migrations",
                 "airflow-core/tests/unit/observability",
                 "airflow-core/tests/unit/plugins",
                 "airflow-core/tests/unit/security",
diff --git a/devel-common/src/tests_common/test_utils/file_loading.py b/devel-common/src/tests_common/test_utils/file_loading.py
index 2653a1f41a3..b05f35fc157 100644
--- a/devel-common/src/tests_common/test_utils/file_loading.py
+++ b/devel-common/src/tests_common/test_utils/file_loading.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import ast
 import json
 import re
 from os.path import join
@@ -44,3 +45,18 @@ def load_file_from_resources(*args: str, mode="r", encoding="utf-8"):
         if mode == "r":
             return remove_license_header(file.read())
         return file.read()
+
+
+def get_imports_from_file(filepath: str) -> set[str]:
+    with open(filepath) as py_file:
+        content = py_file.read()
+    doc_node = ast.parse(content, filepath)
+    import_names: set[str] = set()
+    for current_node in ast.walk(doc_node):
+        if not isinstance(current_node, (ast.Import, ast.ImportFrom)):
+            continue
+        for alias in current_node.names:
+            name = alias.name
+            fullname = f"{current_node.module}.{name}" if isinstance(current_node, ast.ImportFrom) else name
+            import_names.add(fullname)
+    return import_names
diff --git a/scripts/ci/prek/check_tests_in_right_folders.py b/scripts/ci/prek/check_tests_in_right_folders.py
index 43d28d7e4af..c5205903ca1 100755
--- a/scripts/ci/prek/check_tests_in_right_folders.py
+++ b/scripts/ci/prek/check_tests_in_right_folders.py
@@ -61,6 +61,7 @@ POSSIBLE_TEST_FOLDERS = [
     "listeners",
     "logging",
     "macros",
+    "migrations",
     "models",
     "notifications",
     "observability",

Reply via email to