This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4e17ecd3f89 Avoid imports from "providers" (#46801)
4e17ecd3f89 is described below
commit 4e17ecd3f892497e910f4e7df7ecb007a7f3d039
Author: Jarek Potiuk <[email protected]>
AuthorDate: Sun Feb 16 16:12:46 2025 +0100
Avoid imports from "providers" (#46801)
---
contributing-docs/11_provider_packages.rst | 14 +++++++
.../unit/apache/hive/transfers/test_s3_to_hive.py | 7 ++--
.../tests/unit/fab/auth_manager/models/test_db.py | 2 +-
.../unit/microsoft/azure/hooks/test_msgraph.py | 8 ++--
.../unit/microsoft/azure/operators/test_msgraph.py | 22 +++++++----
.../unit/microsoft/azure/sensors/test_msgraph.py | 8 ++--
.../azure/tests/unit/microsoft/azure/test_utils.py | 27 -------------
.../unit/microsoft/azure/triggers/test_msgraph.py | 22 +++++++----
.../tests/unit/microsoft/mssql/hooks/test_mssql.py | 6 ++-
pyproject.toml | 2 +-
tests_common/test_utils/file_loading.py | 46 ++++++++++++++++++++++
11 files changed, 107 insertions(+), 57 deletions(-)
diff --git a/contributing-docs/11_provider_packages.rst
b/contributing-docs/11_provider_packages.rst
index c7c28949b26..383f8d298eb 100644
--- a/contributing-docs/11_provider_packages.rst
+++ b/contributing-docs/11_provider_packages.rst
@@ -74,6 +74,9 @@ PROVIDER is the name of the provider package. It might be
single directory (goog
cases we have a nested structure one level down (``apache/cassandra``,
``apache/druid``, ``microsoft/winrm``,
``common.io`` for example).
+What are the pyproject.toml and provider.yaml files
+---------------------------------------------------
+
On top of the standard ``pyproject.toml`` file where we keep project
information,
we have ``provider.yaml`` file in the provider's module of the ``providers``.
@@ -92,6 +95,9 @@ not modify it - except updating dependencies, as your changes
will be lost.
Eventually we might migrate ``provider.yaml`` fully to ``pyproject.toml`` file
but it would require custom
``tool.airflow`` toml section to be added to the ``pyproject.toml`` file.
+How to manage provider's dependencies
+-------------------------------------
+
If you want to add dependencies to the provider, you should add them to the
corresponding ``pyproject.toml``
file.
@@ -115,6 +121,14 @@ package might be installed when breeze is restarted or by
your IDE or by running
or when you run ``pip install -e "./providers"`` or ``pip install -e
"./providers/<PROVIDER>"`` for the new
provider structure.
+How to reuse code between tests in different providers
+------------------------------------------------------
+
+When you develop providers, you might want to reuse some of the code between
tests in different providers.
+This is possible by placing the code in ``test_utils`` in the ``tests_common``
directory. The ``tests_common``
+module is automatically available on ``sys.path`` when running tests for
the providers, and you can
+import common code from there.
+
Chicken-egg providers
---------------------
diff --git
a/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
b/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
index c0f241bbe87..4d54b6248cf 100644
--- a/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
+++ b/providers/apache/hive/tests/unit/apache/hive/transfers/test_s3_to_hive.py
@@ -29,10 +29,11 @@ from unittest import mock
import pytest
-import providers.microsoft.azure.tests.unit.microsoft.azure.test_utils
from airflow.exceptions import AirflowException
from airflow.providers.apache.hive.transfers.s3_to_hive import
S3ToHiveOperator, uncompress_file
+import tests_common.test_utils.file_loading
+
boto3 = pytest.importorskip("boto3")
moto = pytest.importorskip("moto")
logger = logging.getLogger(__name__)
@@ -218,10 +219,10 @@ class TestS3ToHiveTransfer:
# Upload the file into the Mocked S3 bucket
conn.upload_file(ip_fn, "bucket", self.s3_key + ext)
-
# file parameter to HiveCliHook.load_file is compared
# against expected file output
-
providers.microsoft.azure.tests.unit.microsoft.azure.test_utils.load_file.side_effect
= (
+
+
tests_common.test_utils.file_loading.load_file_from_resources.side_effect = (
lambda *args, **kwargs: self._load_file_side_effect(args,
op_fn, ext)
)
# Execute S3ToHiveTransfer
diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
index c9ef74d3d62..f0920ebb151 100644
--- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
+++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
@@ -24,7 +24,7 @@ from alembic.autogenerate import compare_metadata
from alembic.migration import MigrationContext
from sqlalchemy import MetaData
-import providers.fab.src.airflow.providers.fab as provider_fab
+import airflow.providers.fab as provider_fab
from airflow.settings import engine
from airflow.utils.db import (
compare_server_default,
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
index 6a88739b076..c7604a26665 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import asyncio
import inspect
from json import JSONDecodeError
+from os.path import dirname
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
@@ -44,13 +45,12 @@ from airflow.providers.microsoft.azure.hooks.msgraph import
(
)
from unit.microsoft.azure.test_utils import (
get_airflow_connection,
- load_file,
- load_json,
mock_connection,
mock_json_response,
mock_response,
)
+from tests_common.test_utils.file_loading import load_file_from_resources,
load_json_from_resources
from tests_common.test_utils.providers import get_provider_min_airflow_version
if TYPE_CHECKING:
@@ -313,7 +313,7 @@ class TestKiotaRequestAdapterHook:
class TestResponseHandler:
def test_default_response_handler_when_json(self):
- users = load_json("resources", "users.json")
+ users = load_json_from_resources(dirname(__file__), "..", "resources",
"users.json")
response = mock_json_response(200, users)
actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
@@ -329,7 +329,7 @@ class TestResponseHandler:
assert actual == {}
def test_default_response_handler_when_content(self):
- users = load_file("resources", "users.json").encode()
+ users = load_file_from_resources(dirname(__file__), "..", "resources",
"users.json").encode()
response = mock_response(200, users)
actual =
asyncio.run(DefaultResponseHandler().handle_response_async(response, None))
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
index 1bec5670ca0..a26a7f83451 100644
---
a/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/operators/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import json
import locale
from base64 import b64encode
+from os.path import dirname
from typing import TYPE_CHECKING, Any
import pytest
@@ -27,8 +28,9 @@ from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.operators.msgraph import
MSGraphAsyncOperator
from airflow.triggers.base import TriggerEvent
from unit.microsoft.azure.base import Base
-from unit.microsoft.azure.test_utils import load_file, load_json,
mock_json_response, mock_response
+from unit.microsoft.azure.test_utils import mock_json_response, mock_response
+from tests_common.test_utils.file_loading import load_file_from_resources,
load_json_from_resources
from tests_common.test_utils.mock_context import mock_context
from tests_common.test_utils.operators.run_deferrable import execute_operator
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
@@ -44,8 +46,8 @@ if TYPE_CHECKING:
class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
def test_execute(self):
- users = load_json("resources", "users.json")
- next_users = load_json("resources", "next_users.json")
+ users = load_json_from_resources(dirname(__file__), "..", "resources",
"users.json")
+ next_users = load_json_from_resources(dirname(__file__), "..",
"resources", "next_users.json")
response = mock_json_response(200, users, next_users)
with self.patch_hook_and_request_adapter(response):
@@ -72,7 +74,7 @@ class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
def test_execute_when_do_xcom_push_is_false(self):
- users = load_json("resources", "users.json")
+ users = load_json_from_resources(dirname(__file__), "..", "resources",
"users.json")
users.pop("@odata.nextLink")
response = mock_json_response(200, users)
@@ -134,7 +136,9 @@ class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
def test_execute_when_response_is_bytes(self):
- content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ content = load_file_from_resources(
+ dirname(__file__), "..", "resources", "dummy.pdf", mode="rb",
encoding=None
+ )
base64_encoded_content =
b64encode(content).decode(locale.getpreferredencoding())
drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
response = mock_response(200, content)
@@ -161,7 +165,9 @@ class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters
works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter_when_response_is_bytes(self):
- content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ content = load_file_from_resources(
+ dirname(__file__), "..", "resources", "dummy.pdf", mode="rb",
encoding=None
+ )
base64_encoded_content =
b64encode(content).decode(locale.getpreferredencoding())
drive_id = "82f9d24d-6891-4790-8b6d-f1b2a1d0ca22"
response = mock_response(200, content)
@@ -202,7 +208,7 @@ class TestMSGraphAsyncOperator(Base):
url="users",
)
context = mock_context(task=operator)
- response = load_json("resources", "users.json")
+ response = load_json_from_resources(dirname(__file__), "..",
"resources", "users.json")
next_link, query_parameters = MSGraphAsyncOperator.paginate(operator,
response, context)
assert next_link == response["@odata.nextLink"]
@@ -216,7 +222,7 @@ class TestMSGraphAsyncOperator(Base):
query_parameters={"$top": 12},
)
context = mock_context(task=operator)
- response = load_json("resources", "users.json")
+ response = load_json_from_resources(dirname(__file__), "..",
"resources", "users.json")
response["@odata.count"] = 100
url, query_parameters = MSGraphAsyncOperator.paginate(operator,
response, context)
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
index d08e3d16103..5f42a7c16ab 100644
---
a/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/sensors/test_msgraph.py
@@ -18,21 +18,23 @@ from __future__ import annotations
import json
from datetime import datetime
+from os.path import dirname
import pytest
from airflow.providers.microsoft.azure.sensors.msgraph import MSGraphSensor
from airflow.triggers.base import TriggerEvent
from unit.microsoft.azure.base import Base
-from unit.microsoft.azure.test_utils import load_json, mock_json_response
+from unit.microsoft.azure.test_utils import mock_json_response
+from tests_common.test_utils.file_loading import load_json_from_resources
from tests_common.test_utils.operators.run_deferrable import execute_operator
from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS
class TestMSGraphSensor(Base):
def test_execute(self):
- status = load_json("resources", "status.json")
+ status = load_json_from_resources(dirname(__file__), "..",
"resources", "status.json")
response = mock_json_response(200, *status)
with self.patch_hook_and_request_adapter(response):
@@ -65,7 +67,7 @@ class TestMSGraphSensor(Base):
@pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Lambda parameters
works in Airflow >= 2.10.0")
def test_execute_with_lambda_parameter(self):
- status = load_json("resources", "status.json")
+ status = load_json_from_resources(dirname(__file__), "..",
"resources", "status.json")
response = mock_json_response(200, *status)
with self.patch_hook_and_request_adapter(response):
diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
index c246444b44e..54a82c67a7b 100644
--- a/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
+++ b/providers/microsoft/azure/tests/unit/microsoft/azure/test_utils.py
@@ -17,10 +17,7 @@
from __future__ import annotations
-import json
-import re
from json import JSONDecodeError
-from os.path import dirname, join
from typing import Any
from unittest import mock
from unittest.mock import MagicMock
@@ -237,27 +234,3 @@ def mock_response(status_code, content: Any = None,
headers: dict | None = None)
response.content = content
response.json.side_effect = JSONDecodeError("", "", 0)
return response
-
-
-def remove_license_header(content: str) -> str:
- """Remove license header from the given content."""
- # Define the pattern to match both block and single-line comments
- pattern = r"(/\*.*?\*/)|(--.*?(\r?\n|\r))|(#.*?(\r?\n|\r))"
-
- # Check if there is a license header at the beginning of the file
- if re.match(pattern, content, flags=re.DOTALL):
- # Use re.DOTALL to allow .* to match newline characters in block
comments
- return re.sub(pattern, "", content, flags=re.DOTALL).strip()
- return content.strip()
-
-
-def load_json(*args: str):
- with open(join(dirname(__file__), *args), encoding="utf-8") as file:
- return json.load(file)
-
-
-def load_file(*args: str, mode="r", encoding="utf-8"):
- with open(join(dirname(__file__), *args), mode=mode, encoding=encoding) as
file:
- if mode == "r":
- return remove_license_header(file.read())
- return file.read()
diff --git
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
index d15edb25ab6..8d8f9e84d53 100644
---
a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
+++
b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py
@@ -20,6 +20,7 @@ import json
import locale
from base64 import b64decode, b64encode
from datetime import datetime
+from os.path import dirname
from unittest.mock import patch
from uuid import uuid4
@@ -36,18 +37,17 @@ from airflow.triggers.base import TriggerEvent
from unit.microsoft.azure.base import Base
from unit.microsoft.azure.test_utils import (
get_airflow_connection,
- load_file,
- load_json,
mock_json_response,
mock_response,
)
+from tests_common.test_utils.file_loading import load_file_from_resources,
load_json_from_resources
from tests_common.test_utils.operators.run_deferrable import run_trigger
class TestMSGraphTrigger(Base):
def test_run_when_valid_response(self):
- users = load_json("resources", "users.json")
+ users = load_json_from_resources(dirname(__file__), "..", "resources",
"users.json")
response = mock_json_response(200, users)
with self.patch_hook_and_request_adapter(response):
@@ -83,7 +83,9 @@ class TestMSGraphTrigger(Base):
assert actual.payload["message"] == ""
def test_run_when_response_is_bytes(self):
- content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ content = load_file_from_resources(
+ dirname(__file__), "..", "resources", "dummy.pdf", mode="rb",
encoding=None
+ )
base64_encoded_content =
b64encode(content).decode(locale.getpreferredencoding())
response = mock_response(200, content)
@@ -138,7 +140,9 @@ class TestMSGraphTrigger(Base):
class TestResponseSerializer:
def test_serialize_when_bytes_then_base64_encoded(self):
- response = load_file("resources", "dummy.pdf", mode="rb",
encoding=None)
+ response = load_file_from_resources(
+ dirname(__file__), "..", "resources", "dummy.pdf", mode="rb",
encoding=None
+ )
content = b64encode(response).decode(locale.getpreferredencoding())
actual = ResponseSerializer().serialize(response)
@@ -163,15 +167,17 @@ class TestResponseSerializer:
)
def test_deserialize_when_json(self):
- response = load_file("resources", "users.json")
+ response = load_file_from_resources(dirname(__file__), "..",
"resources", "users.json")
actual = ResponseSerializer().deserialize(response)
assert isinstance(actual, dict)
- assert actual == load_json("resources", "users.json")
+ assert actual == load_json_from_resources(dirname(__file__), "..",
"resources", "users.json")
def test_deserialize_when_base64_encoded_string(self):
- content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)
+ content = load_file_from_resources(
+ dirname(__file__), "..", "resources", "dummy.pdf", mode="rb",
encoding=None
+ )
response = b64encode(content).decode(locale.getpreferredencoding())
actual = ResponseSerializer().deserialize(response)
diff --git
a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
index 43332992e11..2f9e81f66ac 100644
--- a/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
+++ b/providers/microsoft/mssql/tests/unit/microsoft/mssql/hooks/test_mssql.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from os.path import dirname
from unittest import mock
import pytest
@@ -25,7 +26,8 @@ import sqlalchemy
from airflow.configuration import conf
from airflow.models import Connection
from airflow.providers.microsoft.mssql.dialects.mssql import MsSqlDialect
-from unit.microsoft.mssql.test_utils import load_file
+
+from tests_common.test_utils.file_loading import load_file_from_resources
try:
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
@@ -286,7 +288,7 @@ class TestMsSqlHook:
],
replace=True,
)
- assert sql == load_file("resources", "replace.sql")
+ assert sql == load_file_from_resources(dirname(__file__), "..",
"resources", "replace.sql")
def test_dialect_name(self):
hook = MsSqlHook()
diff --git a/pyproject.toml b/pyproject.toml
index a00c53dd503..65f7f6d0539 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -469,6 +469,7 @@ banned-module-level-imports = ["numpy", "pandas"]
"sqlalchemy.ext.declarative.as_declarative".msg = "Use
`sqlalchemy.orm.as_declarative`. Moved in SQLAlchemy 2.0"
"sqlalchemy.ext.declarative.has_inherited_table".msg = "Use
`sqlalchemy.orm.has_inherited_table`. Moved in SQLAlchemy 2.0"
"sqlalchemy.ext.declarative.synonym_for".msg = "Use
`sqlalchemy.orm.synonym_for`. Moved in SQLAlchemy 2.0"
+"providers".msg = "You should not import 'providers' as a Python module.
Imports in providers should be done starting from 'src' or `tests' folders, for
example 'from airflow.providers.airbyte' or 'from unit.airbyte' or 'from
system.airbyte'"
[tool.ruff.lint.flake8-type-checking]
exempt-modules = ["typing", "typing_extensions"]
@@ -542,7 +543,6 @@ python_files = [
testpaths = [
"tests",
]
-
asyncio_default_fixture_loop_scope = "function"
# Keep temporary directories (created by `tmp_path`) for 2 recent runs only
failed tests.
diff --git a/tests_common/test_utils/file_loading.py
b/tests_common/test_utils/file_loading.py
new file mode 100644
index 00000000000..2653a1f41a3
--- /dev/null
+++ b/tests_common/test_utils/file_loading.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+import re
+from os.path import join
+
+
+def remove_license_header(content: str) -> str:
+ """Remove license header from the given content."""
+ # Define the pattern to match both block and single-line comments
+ pattern = r"(/\*.*?\*/)|(--.*?(\r?\n|\r))|(#.*?(\r?\n|\r))"
+
+ # Check if there is a license header at the beginning of the file
+ if re.match(pattern, content, flags=re.DOTALL):
+ # Use re.DOTALL to allow .* to match newline characters in block
comments
+ return re.sub(pattern, "", content, flags=re.DOTALL).strip()
+ return content.strip()
+
+
+def load_json_from_resources(*args: str):
+ with open(join(*args), encoding="utf-8") as file:
+ return json.load(file)
+
+
+def load_file_from_resources(*args: str, mode="r", encoding="utf-8"):
+ with open(join(*args), mode=mode, encoding=encoding) as file:
+ if mode == "r":
+ return remove_license_header(file.read())
+ return file.read()