This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e17ecd3f89 Avoid imports from "providers" (#46801)
4e17ecd3f89 is described below

commit 4e17ecd3f892497e910f4e7df7ecb007a7f3d039
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sun Feb 16 16:12:46 2025 +0100

    Avoid imports from "providers" (#46801)
---
 contributing-docs/11_provider_packages.rst         | 14 +++++++
 .../unit/apache/hive/transfers/test_s3_to_hive.py  |  7 ++--
 .../tests/unit/fab/auth_manager/models/test_db.py  |  2 +-
 .../unit/microsoft/azure/hooks/test_msgraph.py     |  8 ++--
 .../unit/microsoft/azure/operators/test_msgraph.py | 22 +++++++----
 .../unit/microsoft/azure/sensors/test_msgraph.py   |  8 ++--
 .../azure/tests/unit/microsoft/azure/test_utils.py | 27 -------------
 .../unit/microsoft/azure/triggers/test_msgraph.py  | 22 +++++++----
 .../tests/unit/microsoft/mssql/hooks/test_mssql.py |  6 ++-
 pyproject.toml                                     |  2 +-
 tests_common/test_utils/file_loading.py            | 46 ++++++++++++++++++++++
 11 files changed, 107 insertions(+), 57 deletions(-)

diff --git a/contributing-docs/11_provider_packages.rst 
b/contributing-docs/11_provider_packages.rst
index c7c28949b26..383f8d298eb 100644
--- a/contributing-docs/11_provider_packages.rst
+++ b/contributing-docs/11_provider_packages.rst
@@ -74,6 +74,9 @@ PROVIDER is the name of the provider package. It might be 
single directory (goog
 cases we have a nested structure one level down (``apache/cassandra``, 
``apache/druid``, ``microsoft/winrm``,
 ``common.io`` for example).
 
+What are the pyproject.toml and provider.yaml files
+---------------------------------------------------
+
 On top of the standard ``pyproject.toml`` file where we keep project 
information,
 we have ``provider.yaml`` file in the provider's module of the ``providers``.
 
@@ -92,6 +95,9 @@ not modify it - except updating dependencies, as your changes 
will be lost.
 Eventually we might migrate ``provider.yaml`` fully to ``pyproject.toml`` file 
but it would require custom
 ``tool.airflow`` toml section to be added to the ``pyproject.toml`` file.
 
+How to manage provider's dependencies
+-------------------------------------
+
 If you want to add dependencies to the provider, you should add them to the 
corresponding ``pyproject.toml``
 file.
 
@@ -115,6 +121,14 @@ package might be installed when breeze is restarted or by 
your IDE or by running
 or when you run ``pip install -e "./providers"`` or ``pip install -e 
"./providers/<PROVIDER>"`` for the new
 provider structure.
 
+How to reuse code between tests in different providers
+------------------------------------------------------
+
+When you develop providers, you might want to reuse some of the code between 
tests in different providers.
+This is possible by placing the code in ``test_utils`` in the ``tests_common`` 
directory. The ``tests_common``
+module is automatically available in the ``sys.path`` when running tests for 
the providers and you can
+import common code from there.
+
 Chicken-egg providers
 ---------------------
 
diff --git 
a/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py 
b/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
index c0f241bbe87..4d54b6248cf 100644
--- a/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
+++ b/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
@@ -29,10 +29,11 @@ from unittest import mock
 
 import pytest
 
-import providers.microsoft.azure.tests.unit.microsoft.azure.test_utils
 from airflow.exceptions import AirflowException
 from airflow.providers.apache.hive.transfers.s3_to_hive import 
S3ToHiveOperator, uncompress_file
 
+import tests_common.test_utils.file_loading
+
 boto3 = pytest.importorskip("boto3")
 moto = pytest.importorskip("moto")
 logger = logging.getLogger(__name__)
@@ -218,10 +219,10 @@ class TestS3ToHiveTransfer:
 
             # Upload the file into the Mocked S3 bucket
             conn.upload_file(ip_fn, "bucket", self.s3_key + ext)
-
             # file parameter to HiveCliHook.load_file is compared
             # against expected file output
-            
providers.microsoft.azure.tests.unit.microsoft.azure.test_utils.load_file.side_effect
 = (
+
+            
tests_common.test_utils.file_loading.load_file_from_resources.side_effect = (
                 lambda *args, **kwargs: self._load_file_side_effect(args, 
op_fn, ext)
             )
             # Execute S3ToHiveTransfer
diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py 
b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
index c9ef74d3d62..f0920ebb151 100644
--- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
+++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
@@ -24,7 +24,7 @@ from alembic.autogenerate import compare_metadata
 from alembic.migration import MigrationContext
 from sqlalchemy import MetaData
 
-import providers.fab.src.airflow.providers.fab as provider_fab
+import airflow.providers.fab as provider_fab
 from airflow.settings import engine
 from airflow.utils.db import (
     compare_server_default,
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 6a88739b076..c7604a26665 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import asyncio
 import inspect
 from json import JSONDecodeError
+from os.path import dirname
 from typing import TYPE_CHECKING
 from unittest.mock import Mock, patch
 
@@ -44,13 +45,12 @@ from airflow.providers.microsoft.azure.hooks.msgraph import 
(
 )
 from unit.microsoft.azure.test_utils import (
     get_airflow_connection,
-    load_file,
-    load_json,
     mock_connection,
     mock_json_response,
     mock_response,
 )
 
+from tests_common.test_utils.file_loading import load_file_from_resources, 
load_json_from_resources
 from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 if TYPE_CHECKING:
@@ -313,7 +313,7 @@ class TestKiotaRequestAdapterHook:
 
 class TestResponseHandler:
     def test_default_response_handler_when_json(self):
-        users = load_json("resources", "users.json")
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
         response = mock_json_response(200, users)
 
         actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
@@ -329,7 +329,7 @@ class TestResponseHandler:
         assert actual == {}
 
     def test_default_response_handler_when_content(self):
-        users = load_file("resources", "users.json").encode()
+        users = load_file_from_resources(dirname(__file__), "..", "resources", 
"users.json").encode()
         response = mock_response(200, users)
 
         actual = 
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
 
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
index 1bec5670ca0..a26a7f83451 100644
--- 
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
+++ 
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import json
 import locale
 from base64 import b64encode
+from os.path import dirname
 from typing import TYPE_CHECKING, Any
 
 import pytest
@@ -27,8 +28,9 @@ from airflow.exceptions import AirflowException
 from airflow.providers.microsoft.azure.operators.msgraph import 
MSGraphAsyncOperator
 from airflow.triggers.base import TriggerEvent
 from unit.microsoft.azure.base import Base
-from unit.microsoft.azure.test_utils import load_file, load_json, 
mock_json_response, mock_response
+from unit.microsoft.azure.test_utils import mock_json_response, mock_response
 
+from tests_common.test_utils.file_loading import load_file_from_resources, 
load_json_from_resources
 from tests_common.test_utils.mock_context import mock_context
 from tests_common.test_utils.operators.run_deferrable import execute_operator
 from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
@@ -44,8 +46,8 @@ if TYPE_CHECKING:
 class TestMSGraphAsyncOperator(Base):
     @pytest.mark.db_test
     def test_execute(self):
-        users = load_json("resources", "users.json")
-        next_users = load_json("resources", "next_users.json")
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
+        next_users = load_json_from_resources(dirname(__file__), "..", 
"resources", "next_users.json")
         response = mock_json_response(200, users, next_users)
 
         with self.patch_hook_and_request_adapter(response):
@@ -72,7 +74,7 @@ class TestMSGraphAsyncOperator(Base):
 
     @pytest.mark.db_test
     def test_execute_when_do_xcom_push_is_false(self):
-        users = load_json("resources", "users.json")
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
         users.pop("@odata.nextLink")
         response = mock_json_response(200, users)
 
@@ -134,7 +136,9 @@ class TestMSGraphAsyncOperator(Base):
 
     @pytest.mark.db_test
     def test_execute_when_response_is_bytes(self):
-        content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+        content = load_file_from_resources(
+            dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", 
encoding=None
+        )
         base64_encoded_content = 
b64encode(content).decode(locale.getpreferredencoding())
         drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
         response = mock_response(200, content)
@@ -161,7 +165,9 @@ class TestMSGraphAsyncOperator(Base):
     @pytest.mark.db_test
     @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters 
works in Airflow >= 2.10.0")
     def test_execute_with_lambda_parameter_when_response_is_bytes(self):
-        content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+        content = load_file_from_resources(
+            dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", 
encoding=None
+        )
         base64_encoded_content = 
b64encode(content).decode(locale.getpreferredencoding())
         drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
         response = mock_response(200, content)
@@ -202,7 +208,7 @@ class TestMSGraphAsyncOperator(Base):
             url="users",
         )
         context = mock_context(task=operator)
-        response = load_json("resources", "users.json")
+        response = load_json_from_resources(dirname(__file__), "..", 
"resources", "users.json")
         next_link, query_parameters = MSGraphAsyncOperator.paginate(operator, 
response, context)
 
         assert next_link == response["@odata.nextLink"]
@@ -216,7 +222,7 @@ class TestMSGraphAsyncOperator(Base):
             query_parameters={"$top": 12},
         )
         context = mock_context(task=operator)
-        response = load_json("resources", "users.json")
+        response = load_json_from_resources(dirname(__file__), "..", 
"resources", "users.json")
         response["@odata.count"] = 100
         url, query_parameters = MSGraphAsyncOperator.paginate(operator, 
response, context)
 
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
index d08e3d16103..5f42a7c16ab 100644
--- 
a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
+++ 
b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
@@ -18,21 +18,23 @@ from __future__ import annotations
 
 import json
 from datetime import datetime
+from os.path import dirname
 
 import pytest
 
 from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
 from airflow.triggers.base import TriggerEvent
 from unit.microsoft.azure.base import Base
-from unit.microsoft.azure.test_utils import load_json, mock_json_response
+from unit.microsoft.azure.test_utils import mock_json_response
 
+from tests_common.test_utils.file_loading import load_json_from_resources
 from tests_common.test_utils.operators.run_deferrable import execute_operator
 from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
 
 
 class TestMSGraphSensor(Base):
     def test_execute(self):
-        status = load_json("resources", "status.json")
+        status = load_json_from_resources(dirname(__file__), "..", 
"resources", "status.json")
         response = mock_json_response(200, *status)
 
         with self.patch_hook_and_request_adapter(response):
@@ -65,7 +67,7 @@ class TestMSGraphSensor(Base):
 
     @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters 
works in Airflow >= 2.10.0")
     def test_execute_with_lambda_parameter(self):
-        status = load_json("resources", "status.json")
+        status = load_json_from_resources(dirname(__file__), "..", 
"resources", "status.json")
         response = mock_json_response(200, *status)
 
         with self.patch_hook_and_request_adapter(response):
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
index c246444b44e..54a82c67a7b 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
@@ -17,10 +17,7 @@
 
 from __future__ import annotations
 
-import json
-import re
 from json import JSONDecodeError
-from os.path import dirname, join
 from typing import Any
 from unittest import mock
 from unittest.mock import MagicMock
@@ -237,27 +234,3 @@ def mock_response(status_code, content: Any = None, 
headers: dict | None = None)
     response.content = content
     response.json.side_effect = JSONDecodeError("", "", 0)
     return response
-
-
-def remove_license_header(content: str) -> str:
-    """Remove license header from the given content."""
-    # Define the pattern to match both block and single-line comments
-    pattern = r"(/\*.*?\*/)|(--.*?(\r?\n|\r))|(#.*?(\r?\n|\r))"
-
-    # Check if there is a license header at the beginning of the file
-    if re.match(pattern, content, flags=re.DOTALL):
-        # Use re.DOTALL to allow .* to match newline characters in block 
comments
-        return re.sub(pattern, "", content, flags=re.DOTALL).strip()
-    return content.strip()
-
-
-def load_json(*args: str):
-    with open(join(dirname(__file__), *args), encoding="utf-8") as file:
-        return json.load(file)
-
-
-def load_file(*args: str, mode="r", encoding="utf-8"):
-    with open(join(dirname(__file__), *args), mode=mode, encoding=encoding) as 
file:
-        if mode == "r":
-            return remove_license_header(file.read())
-        return file.read()
diff --git 
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py 
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
index d15edb25ab6..8d8f9e84d53 100644
--- 
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
+++ 
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
@@ -20,6 +20,7 @@ import json
 import locale
 from base64 import b64decode, b64encode
 from datetime import datetime
+from os.path import dirname
 from unittest.mock import patch
 from uuid import uuid4
 
@@ -36,18 +37,17 @@ from airflow.triggers.base import TriggerEvent
 from unit.microsoft.azure.base import Base
 from unit.microsoft.azure.test_utils import (
     get_airflow_connection,
-    load_file,
-    load_json,
     mock_json_response,
     mock_response,
 )
 
+from tests_common.test_utils.file_loading import load_file_from_resources, 
load_json_from_resources
 from tests_common.test_utils.operators.run_deferrable import run_trigger
 
 
 class TestMSGraphTrigger(Base):
     def test_run_when_valid_response(self):
-        users = load_json("resources", "users.json")
+        users = load_json_from_resources(dirname(__file__), "..", "resources", 
"users.json")
         response = mock_json_response(200, users)
 
         with self.patch_hook_and_request_adapter(response):
@@ -83,7 +83,9 @@ class TestMSGraphTrigger(Base):
             assert actual.payload["message"] == ""
 
     def test_run_when_response_is_bytes(self):
-        content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+        content = load_file_from_resources(
+            dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", 
encoding=None
+        )
         base64_encoded_content = 
b64encode(content).decode(locale.getpreferredencoding())
         response = mock_response(200, content)
 
@@ -138,7 +140,9 @@ class TestMSGraphTrigger(Base):
 
 class TestResponseSerializer:
     def test_serialize_when_bytes_then_base64_encoded(self):
-        response = load_file("resources", "dummy.pdf", mode="rb", 
encoding=None)
+        response = load_file_from_resources(
+            dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", 
encoding=None
+        )
         content = b64encode(response).decode(locale.getpreferredencoding())
 
         actual = ResponseSerializer().serialize(response)
@@ -163,15 +167,17 @@ class TestResponseSerializer:
         )
 
     def test_deserialize_when_json(self):
-        response = load_file("resources", "users.json")
+        response = load_file_from_resources(dirname(__file__), "..", 
"resources", "users.json")
 
         actual = ResponseSerializer().deserialize(response)
 
         assert isinstance(actual, dict)
-        assert actual == load_json("resources", "users.json")
+        assert actual == load_json_from_resources(dirname(__file__), "..", 
"resources", "users.json")
 
     def test_deserialize_when_base64_encoded_string(self):
-        content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+        content = load_file_from_resources(
+            dirname(__file__), "..", "resources", "dummy.pdf", mode="rb", 
encoding=None
+        )
         response = b64encode(content).decode(locale.getpreferredencoding())
 
         actual = ResponseSerializer().deserialize(response)
diff --git 
a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py 
b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
index 43332992e11..2f9e81f66ac 100644
--- a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
+++ b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from os.path import dirname
 from unittest import mock
 
 import pytest
@@ -25,7 +26,8 @@ import sqlalchemy
 from airflow.configuration import conf
 from airflow.models import Connection
 from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
-from unit.microsoft.mssql.test_utils import load_file
+
+from tests_common.test_utils.file_loading import load_file_from_resources
 
 try:
     from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
@@ -286,7 +288,7 @@ class TestMsSqlHook:
             ],
             replace=True,
         )
-        assert sql == load_file("resources", "replace.sql")
+        assert sql == load_file_from_resources(dirname(__file__), "..", 
"resources", "replace.sql")
 
     def test_dialect_name(self):
         hook = MsSqlHook()
diff --git a/pyproject.toml b/pyproject.toml
index a00c53dd503..65f7f6d0539 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -469,6 +469,7 @@ banned-module-level-imports = ["numpy", "pandas"]
 "sqlalchemy.ext.declarative.as_declarative".msg = "Use 
`sqlalchemy.orm.as_declarative`. Moved in SQLAlchemy 2.0"
 "sqlalchemy.ext.declarative.has_inherited_table".msg = "Use 
`sqlalchemy.orm.has_inherited_table`. Moved in SQLAlchemy 2.0"
 "sqlalchemy.ext.declarative.synonym_for".msg = "Use 
`sqlalchemy.orm.synonym_for`. Moved in SQLAlchemy 2.0"
+"providers".msg = "You should not import 'providers' as a Python module. 
Imports in providers should be done starting from 'src' or 'tests' folders, for 
example 'from airflow.providers.airbyte' or 'from unit.airbyte' or 'from 
system.airbyte'"
 
 [tool.ruff.lint.flake8-type-checking]
 exempt-modules = ["typing", "typing_extensions"]
@@ -542,7 +543,6 @@ python_files = [
 testpaths = [
     "tests",
 ]
-
 asyncio_default_fixture_loop_scope = "function"
 
 # Keep temporary directories (created by `tmp_path`) only for the 2 most 
recent runs of failed tests.
diff --git a/tests_common/test_utils/file_loading.py 
b/tests_common/test_utils/file_loading.py
new file mode 100644
index 00000000000..2653a1f41a3
--- /dev/null
+++ b/tests_common/test_utils/file_loading.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import re
+from os.path import join
+
+
+def remove_license_header(content: str) -> str:
+    """Remove license header from the given content."""
+    # Define the pattern to match both block and single-line comments
+    pattern = r"(/\*.*?\*/)|(--.*?(\r?\n|\r))|(#.*?(\r?\n|\r))"
+
+    # Check if there is a license header at the beginning of the file
+    if re.match(pattern, content, flags=re.DOTALL):
+        # Use re.DOTALL to allow .* to match newline characters in block 
comments
+        return re.sub(pattern, "", content, flags=re.DOTALL).strip()
+    return content.strip()
+
+
+def load_json_from_resources(*args: str):
+    with open(join(*args), encoding="utf-8") as file:
+        return json.load(file)
+
+
+def load_file_from_resources(*args: str, mode="r", encoding="utf-8"):
+    with open(join(*args), mode=mode, encoding=encoding) as file:
+        if mode == "r":
+            return remove_license_header(file.read())
+        return file.read()

Reply via email to