Copilot commented on code in PR #64941:
URL: https://github.com/apache/airflow/pull/64941#discussion_r3066493766
##########
providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py:
##########
@@ -70,18 +44,41 @@ def get_scheme(self) -> str:
return "file://"
+_STORAGE_TYPE_PROVIDER_HINTS: dict[str, str] = {
+ "s3": "apache-airflow-providers-amazon[datafusion]",
+}
+
+
def get_object_storage_provider(storage_type: StorageType) ->
ObjectStorageProvider:
"""Get an object storage provider based on the storage type."""
- # TODO: Add support for GCS, Azure, HTTP:
https://datafusion.apache.org/python/autoapi/datafusion/object_store/index.html
- providers: dict[StorageType, type] = {
- StorageType.S3: S3ObjectStorageProvider,
- StorageType.LOCAL: LocalObjectStorageProvider,
- }
-
- if storage_type not in providers:
- raise ValueError(
- f"Unsupported storage type: {storage_type}. Supported types:
{list(providers.keys())}"
+ if storage_type == StorageType.LOCAL:
+ return LocalObjectStorageProvider()
+
+ from airflow.providers_manager import ProvidersManager
+
+ registry = ProvidersManager().object_storage_providers
+ type_key = storage_type.value
Review Comment:
`get_object_storage_provider()` assumes `storage_type` is a `StorageType`
enum and will raise an `AttributeError` (on `.value`) if a caller passes an
invalid type (e.g. a plain string). Previously this surfaced as a `ValueError`
with a clearer message; consider adding an explicit `isinstance(storage_type,
StorageType)` check and raising a `ValueError` for unsupported/invalid inputs
to keep the API behavior predictable.
##########
providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py:
##########
@@ -16,59 +16,111 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
-from airflow.providers.common.sql.config import ConnectionConfig, StorageType
-from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+from airflow.providers.common.sql.config import StorageType
from airflow.providers.common.sql.datafusion.object_storage_provider import (
LocalObjectStorageProvider,
- S3ObjectStorageProvider,
get_object_storage_provider,
)
-class TestObjectStorageProvider:
-
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3")
- def test_s3_provider_success(self, mock_s3):
- provider = S3ObjectStorageProvider()
- connection_config = ConnectionConfig(
- conn_id="aws_default",
- credentials={"access_key_id": "fake_key", "secret_access_key":
"fake_secret"},
- )
+class TestLocalObjectStorageProvider:
+ @patch(
+
"airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem",
+ autospec=True,
+ )
+ def test_local_provider(self, mock_local):
+ provider = LocalObjectStorageProvider()
+ assert provider.get_storage_type == StorageType.LOCAL
+ assert provider.get_scheme() == "file://"
+ local_store = provider.create_object_store("file://path")
+ assert local_store == mock_local.return_value
+
- store = provider.create_object_store("s3://demo-data/path",
connection_config)
+class TestGetObjectStorageProvider:
+ def test_returns_local_provider_directly(self):
+ provider = get_object_storage_provider(StorageType.LOCAL)
+ assert isinstance(provider, LocalObjectStorageProvider)
- mock_s3.assert_called_once_with(
- access_key_id="fake_key", secret_access_key="fake_secret",
bucket_name="demo-data"
+
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.import_string",
autospec=True)
+ @patch("airflow.providers_manager.ProvidersManager", autospec=True)
+ def test_resolves_s3_via_registry(self, mock_pm_cls, mock_import_string):
+ mock_provider_cls = MagicMock()
+ mock_import_string.return_value = mock_provider_cls
+
+ mock_pm_cls.return_value.object_storage_providers = {
+ "s3": MagicMock(
+
provider_class_name="airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider",
+ ),
+ }
Review Comment:
Several `MagicMock()` instances here are created without `spec`/`autospec`,
which can mask real attribute/method typos in tests. Prefer
`MagicMock(spec=...)` (or using `autospec=True` where patching) so the mocks
match the shape of the real objects being replaced.
##########
providers/common/sql/src/airflow/providers/common/sql/datafusion/object_storage_provider.py:
##########
@@ -70,18 +44,41 @@ def get_scheme(self) -> str:
return "file://"
+_STORAGE_TYPE_PROVIDER_HINTS: dict[str, str] = {
+ "s3": "apache-airflow-providers-amazon[datafusion]",
+}
+
+
def get_object_storage_provider(storage_type: StorageType) ->
ObjectStorageProvider:
"""Get an object storage provider based on the storage type."""
- # TODO: Add support for GCS, Azure, HTTP:
https://datafusion.apache.org/python/autoapi/datafusion/object_store/index.html
- providers: dict[StorageType, type] = {
- StorageType.S3: S3ObjectStorageProvider,
- StorageType.LOCAL: LocalObjectStorageProvider,
- }
-
- if storage_type not in providers:
- raise ValueError(
- f"Unsupported storage type: {storage_type}. Supported types:
{list(providers.keys())}"
+ if storage_type == StorageType.LOCAL:
+ return LocalObjectStorageProvider()
+
+ from airflow.providers_manager import ProvidersManager
+
+ registry = ProvidersManager().object_storage_providers
+ type_key = storage_type.value
+
+ if type_key in registry:
+ provider_cls = import_string(registry[type_key].provider_class_name)
+ return provider_cls()
+
+ hint = _STORAGE_TYPE_PROVIDER_HINTS.get(type_key, "the appropriate
provider package")
Review Comment:
When a storage type is registered, `import_string(...)` can still fail (e.g.
provider installed without the required extra like
`apache-airflow-providers-amazon[datafusion]`, or a bad class path). Right now
that ImportError will bubble up and bypass the friendly install-hint
ValueError; consider catching `ImportError` (and possibly `Exception`) around
`import_string`/instantiation and raising a `ValueError` that preserves the
underlying error while still including the install/upgrade hint.
```suggestion
    hint = _STORAGE_TYPE_PROVIDER_HINTS.get(type_key, "the appropriate provider package")
    if type_key in registry:
        try:
            provider_cls = import_string(registry[type_key].provider_class_name)
            return provider_cls()
        except ImportError as exc:
            raise ValueError(
                f"Failed to import ObjectStorageProvider for storage type '{type_key}'. "
                f"Install or upgrade {hint}."
            ) from exc
        except Exception as exc:
            raise ValueError(
                f"Failed to initialize ObjectStorageProvider for storage type '{type_key}'. "
                f"Install or upgrade {hint}."
            ) from exc
```
##########
providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py:
##########
@@ -16,59 +16,111 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
-from airflow.providers.common.sql.config import ConnectionConfig, StorageType
-from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+from airflow.providers.common.sql.config import StorageType
from airflow.providers.common.sql.datafusion.object_storage_provider import (
LocalObjectStorageProvider,
- S3ObjectStorageProvider,
get_object_storage_provider,
)
-class TestObjectStorageProvider:
-
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3")
- def test_s3_provider_success(self, mock_s3):
- provider = S3ObjectStorageProvider()
- connection_config = ConnectionConfig(
- conn_id="aws_default",
- credentials={"access_key_id": "fake_key", "secret_access_key":
"fake_secret"},
- )
+class TestLocalObjectStorageProvider:
+ @patch(
+
"airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem",
+ autospec=True,
+ )
+ def test_local_provider(self, mock_local):
+ provider = LocalObjectStorageProvider()
+ assert provider.get_storage_type == StorageType.LOCAL
+ assert provider.get_scheme() == "file://"
+ local_store = provider.create_object_store("file://path")
+ assert local_store == mock_local.return_value
+
- store = provider.create_object_store("s3://demo-data/path",
connection_config)
+class TestGetObjectStorageProvider:
+ def test_returns_local_provider_directly(self):
+ provider = get_object_storage_provider(StorageType.LOCAL)
+ assert isinstance(provider, LocalObjectStorageProvider)
- mock_s3.assert_called_once_with(
- access_key_id="fake_key", secret_access_key="fake_secret",
bucket_name="demo-data"
+
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.import_string",
autospec=True)
+ @patch("airflow.providers_manager.ProvidersManager", autospec=True)
+ def test_resolves_s3_via_registry(self, mock_pm_cls, mock_import_string):
+ mock_provider_cls = MagicMock()
+ mock_import_string.return_value = mock_provider_cls
+
+ mock_pm_cls.return_value.object_storage_providers = {
+ "s3": MagicMock(
+
provider_class_name="airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider",
+ ),
+ }
+
+ provider = get_object_storage_provider(StorageType.S3)
+
+ mock_import_string.assert_called_once_with(
+
"airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider"
)
- assert store == mock_s3.return_value
- assert provider.get_storage_type == StorageType.S3
- assert provider.get_scheme() == "s3://"
+ assert provider == mock_provider_cls.return_value
+
+ @patch("airflow.providers_manager.ProvidersManager", autospec=True)
+ def test_unregistered_storage_type_raises(self, mock_pm_cls):
+ mock_pm_cls.return_value.object_storage_providers = {}
+
+ with pytest.raises(ValueError, match="No ObjectStorageProvider
registered.*Install or upgrade"):
+ get_object_storage_provider(StorageType.S3)
+
+ def test_error_message_includes_install_hint_for_s3(self):
+ with patch("airflow.providers_manager.ProvidersManager",
autospec=True) as mock_pm_cls:
+ mock_pm_cls.return_value.object_storage_providers = {}
+
+ with pytest.raises(ValueError,
match="apache-airflow-providers-amazon"):
+ get_object_storage_provider(StorageType.S3)
- def test_s3_provider_failure(self):
- provider = S3ObjectStorageProvider()
- connection_config = ConnectionConfig(conn_id="aws_default")
+ def test_no_amazon_imports_at_module_level(self):
+ """Verify common-sql no longer statically imports amazon provider code
at the top level."""
+ import airflow.providers.common.sql.datafusion.object_storage_provider
as mod
- with patch(
-
"airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3",
- side_effect=Exception("Error"),
+ top_level_names = [
+ name
+ for name, obj in vars(mod).items()
+ if not name.startswith("_")
+ and hasattr(obj, "__module__")
+ and "amazon" in getattr(obj, "__module__", "")
+ ]
+ assert top_level_names == [], f"Amazon symbols found at module level:
{top_level_names}"
+
+
+class TestS3DeprecationShim:
+ def test_old_import_path_emits_deprecation_warning(self):
+ """Importing S3ObjectStorageProvider from the old path still works but
warns."""
+ pytest.importorskip("airflow.providers.amazon")
+ import airflow.providers.common.sql.datafusion.object_storage_provider
as mod
+
+ with pytest.warns(
+ match="Import it from airflow.providers.amazon",
):
- with pytest.raises(ObjectStoreCreationException, match="Failed to
create S3 object store"):
- provider.create_object_store("s3://demo-data/path",
connection_config)
+ cls = mod.S3ObjectStorageProvider
Review Comment:
These deprecation-shim tests use `pytest.warns(match=...)` without
specifying the expected warning class. To avoid passing on unrelated warnings,
assert the warning type explicitly (e.g. `AirflowProviderDeprecationWarning`)
in addition to the message match.
##########
providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.config import ConnectionConfig, StorageType
+from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+
+
+class TestS3ObjectStorageProvider:
+ """Tests for S3ObjectStorageProvider in the amazon provider package."""
+
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id="aws_default",
+ conn_type="aws",
+ login="fake_id",
+ password="fake_secret",
+ extra='{"region": "us-east-1"}',
+ )
+ )
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_with_login_password(self, mock_hook_cls, mock_s3):
+ """Login/password on the connection override hook credentials."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = None
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_default")
+
+ store = provider.create_object_store("s3://demo-data/path",
connection_config=config)
+
+ mock_s3.assert_called_once_with(
+ access_key_id="fake_id",
+ secret_access_key="fake_secret",
+ region="us-east-1",
+ bucket_name="demo-data",
+ )
+ assert store == mock_s3.return_value
+ assert provider.get_storage_type == StorageType.S3
+ assert provider.get_scheme() == "s3://"
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_falls_back_to_hook_credentials(self, mock_hook_cls,
mock_s3):
+ """When login/password are empty, hook credentials are used."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = "session_tok"
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_no_login")
+
+ with patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.BaseHook.get_connection",
+ return_value=Connection(
+ conn_id="aws_no_login",
+ conn_type="aws",
+ extra='{"endpoint": "http://localhost:4566"}',
+ ),
+ ):
+ store = provider.create_object_store("s3://bucket/path",
connection_config=config)
+
+ mock_s3.assert_called_once_with(
+ access_key_id="hook_key",
+ secret_access_key="hook_secret",
+ session_token="session_tok",
+ endpoint="http://localhost:4566",
+ bucket_name="bucket",
+ )
+ assert store == mock_s3.return_value
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_session_token(self, mock_hook_cls, mock_s3):
+ """Session token from hook is forwarded when present."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = "my_session_token"
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
Review Comment:
Same issue here: `mock_creds` is an unspec'd `MagicMock()`. Consider using a
`spec`/`autospec`-based mock or an explicit object to ensure the test fails if
unexpected attributes are accessed.
##########
providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.config import ConnectionConfig, StorageType
+from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+
+
+class TestS3ObjectStorageProvider:
+ """Tests for S3ObjectStorageProvider in the amazon provider package."""
+
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id="aws_default",
+ conn_type="aws",
+ login="fake_id",
+ password="fake_secret",
+ extra='{"region": "us-east-1"}',
+ )
+ )
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_with_login_password(self, mock_hook_cls, mock_s3):
+ """Login/password on the connection override hook credentials."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = None
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
Review Comment:
`mock_creds` is a `MagicMock()` without a `spec`, which can hide attribute
mistakes in the test (e.g. typos in `access_key`/`secret_key`/`token`). Prefer
`MagicMock(spec=...)` or a small simple object (e.g. `types.SimpleNamespace`)
with explicit attributes.
##########
providers/common/sql/tests/unit/common/sql/datafusion/test_object_storage_provider.py:
##########
@@ -16,59 +16,111 @@
# under the License.
from __future__ import annotations
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import pytest
-from airflow.providers.common.sql.config import ConnectionConfig, StorageType
-from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+from airflow.providers.common.sql.config import StorageType
from airflow.providers.common.sql.datafusion.object_storage_provider import (
LocalObjectStorageProvider,
- S3ObjectStorageProvider,
get_object_storage_provider,
)
-class TestObjectStorageProvider:
-
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3")
- def test_s3_provider_success(self, mock_s3):
- provider = S3ObjectStorageProvider()
- connection_config = ConnectionConfig(
- conn_id="aws_default",
- credentials={"access_key_id": "fake_key", "secret_access_key":
"fake_secret"},
- )
+class TestLocalObjectStorageProvider:
+ @patch(
+
"airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem",
+ autospec=True,
+ )
+ def test_local_provider(self, mock_local):
+ provider = LocalObjectStorageProvider()
+ assert provider.get_storage_type == StorageType.LOCAL
+ assert provider.get_scheme() == "file://"
+ local_store = provider.create_object_store("file://path")
+ assert local_store == mock_local.return_value
+
- store = provider.create_object_store("s3://demo-data/path",
connection_config)
+class TestGetObjectStorageProvider:
+ def test_returns_local_provider_directly(self):
+ provider = get_object_storage_provider(StorageType.LOCAL)
+ assert isinstance(provider, LocalObjectStorageProvider)
- mock_s3.assert_called_once_with(
- access_key_id="fake_key", secret_access_key="fake_secret",
bucket_name="demo-data"
+
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.import_string",
autospec=True)
+ @patch("airflow.providers_manager.ProvidersManager", autospec=True)
+ def test_resolves_s3_via_registry(self, mock_pm_cls, mock_import_string):
+ mock_provider_cls = MagicMock()
+ mock_import_string.return_value = mock_provider_cls
+
+ mock_pm_cls.return_value.object_storage_providers = {
+ "s3": MagicMock(
+
provider_class_name="airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider",
+ ),
+ }
+
+ provider = get_object_storage_provider(StorageType.S3)
+
+ mock_import_string.assert_called_once_with(
+
"airflow.providers.amazon.aws.datafusion.object_storage.S3ObjectStorageProvider"
)
- assert store == mock_s3.return_value
- assert provider.get_storage_type == StorageType.S3
- assert provider.get_scheme() == "s3://"
+ assert provider == mock_provider_cls.return_value
+
+ @patch("airflow.providers_manager.ProvidersManager", autospec=True)
+ def test_unregistered_storage_type_raises(self, mock_pm_cls):
+ mock_pm_cls.return_value.object_storage_providers = {}
+
+ with pytest.raises(ValueError, match="No ObjectStorageProvider
registered.*Install or upgrade"):
+ get_object_storage_provider(StorageType.S3)
+
+ def test_error_message_includes_install_hint_for_s3(self):
+ with patch("airflow.providers_manager.ProvidersManager",
autospec=True) as mock_pm_cls:
+ mock_pm_cls.return_value.object_storage_providers = {}
+
+ with pytest.raises(ValueError,
match="apache-airflow-providers-amazon"):
+ get_object_storage_provider(StorageType.S3)
- def test_s3_provider_failure(self):
- provider = S3ObjectStorageProvider()
- connection_config = ConnectionConfig(conn_id="aws_default")
+ def test_no_amazon_imports_at_module_level(self):
+ """Verify common-sql no longer statically imports amazon provider code
at the top level."""
+ import airflow.providers.common.sql.datafusion.object_storage_provider
as mod
- with patch(
-
"airflow.providers.common.sql.datafusion.object_storage_provider.AmazonS3",
- side_effect=Exception("Error"),
+ top_level_names = [
+ name
+ for name, obj in vars(mod).items()
+ if not name.startswith("_")
+ and hasattr(obj, "__module__")
+ and "amazon" in getattr(obj, "__module__", "")
+ ]
+ assert top_level_names == [], f"Amazon symbols found at module level:
{top_level_names}"
+
+
+class TestS3DeprecationShim:
+ def test_old_import_path_emits_deprecation_warning(self):
+ """Importing S3ObjectStorageProvider from the old path still works but
warns."""
+ pytest.importorskip("airflow.providers.amazon")
+ import airflow.providers.common.sql.datafusion.object_storage_provider
as mod
+
+ with pytest.warns(
+ match="Import it from airflow.providers.amazon",
):
- with pytest.raises(ObjectStoreCreationException, match="Failed to
create S3 object store"):
- provider.create_object_store("s3://demo-data/path",
connection_config)
+ cls = mod.S3ObjectStorageProvider
-
@patch("airflow.providers.common.sql.datafusion.object_storage_provider.LocalFileSystem")
- def test_local_provider(self, mock_local):
- provider = LocalObjectStorageProvider()
- assert provider.get_storage_type == StorageType.LOCAL
- assert provider.get_scheme() == "file://"
- local_store = provider.create_object_store("file://path")
- assert local_store == mock_local.return_value
+ assert cls.__name__ == "S3ObjectStorageProvider"
+
+ def test_old_import_path_returns_same_class(self):
+ """The shim re-exports the exact same class from the new location."""
+ pytest.importorskip("airflow.providers.amazon")
+ import airflow.providers.common.sql.datafusion.object_storage_provider
as mod
+
+ with pytest.warns(
+ match="Import it from airflow.providers.amazon",
+ ):
+ old_cls = mod.S3ObjectStorageProvider
Review Comment:
Same as above: `pytest.warns(match=...)` should specify the warning category
to ensure the shim is emitting `AirflowProviderDeprecationWarning` (and not
just any warning that happens to match the message).
##########
providers/common/sql/pyproject.toml:
##########
@@ -100,6 +97,9 @@ dependencies = [
"apache.iceberg" = [
"apache-airflow-providers-apache-iceberg"
]
+"amazon" = [
+ "apache-airflow-providers-amazon"
+]
Review Comment:
The PR description says the amazon dependency is removed from `common-sql`,
but the `amazon` optional-dependency is still present here. Also, since
S3/DataFusion now requires `apache-airflow-providers-amazon[datafusion]`,
depending on plain `apache-airflow-providers-amazon` risks installs where the
registry entry exists but importing the provider fails due to missing
`datafusion`. Either remove this optional dependency (and update dev deps
accordingly) or change it to `apache-airflow-providers-amazon[datafusion]` and
align the PR description.
##########
providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.config import ConnectionConfig, StorageType
+from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+
+
+class TestS3ObjectStorageProvider:
+ """Tests for S3ObjectStorageProvider in the amazon provider package."""
+
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id="aws_default",
+ conn_type="aws",
+ login="fake_id",
+ password="fake_secret",
+ extra='{"region": "us-east-1"}',
+ )
+ )
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_with_login_password(self, mock_hook_cls, mock_s3):
+ """Login/password on the connection override hook credentials."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = None
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_default")
+
+ store = provider.create_object_store("s3://demo-data/path",
connection_config=config)
+
+ mock_s3.assert_called_once_with(
+ access_key_id="fake_id",
+ secret_access_key="fake_secret",
+ region="us-east-1",
+ bucket_name="demo-data",
+ )
+ assert store == mock_s3.return_value
+ assert provider.get_storage_type == StorageType.S3
+ assert provider.get_scheme() == "s3://"
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_falls_back_to_hook_credentials(self, mock_hook_cls,
mock_s3):
+ """When login/password are empty, hook credentials are used."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = "session_tok"
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_no_login")
+
+ with patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.BaseHook.get_connection",
+ return_value=Connection(
+ conn_id="aws_no_login",
+ conn_type="aws",
+ extra='{"endpoint": "http://localhost:4566"}',
+ ),
+ ):
+ store = provider.create_object_store("s3://bucket/path",
connection_config=config)
+
+ mock_s3.assert_called_once_with(
+ access_key_id="hook_key",
+ secret_access_key="hook_secret",
+ session_token="session_tok",
+ endpoint="http://localhost:4566",
+ bucket_name="bucket",
+ )
+ assert store == mock_s3.return_value
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_session_token(self, mock_hook_cls, mock_s3):
+ """Session token from hook is forwarded when present."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = "my_session_token"
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_default")
+
+ store = provider.create_object_store("s3://bucket/path",
connection_config=config)
+
+ call_kwargs = mock_s3.call_args.kwargs
+ assert call_kwargs["session_token"] == "my_session_token"
+ assert store == mock_s3.return_value
+
+ def test_s3_provider_missing_connection_config(self):
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ provider = S3ObjectStorageProvider()
+ with pytest.raises(ValueError, match="connection_config must be
provided"):
+ provider.create_object_store("s3://bucket/path",
connection_config=None)
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_creation_failure(self, mock_hook_cls, mock_s3):
+ """Internal exceptions are wrapped in ObjectStoreCreationException."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "k"
+ mock_creds.secret_key = "s"
+ mock_creds.token = None
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
Review Comment:
`mock_creds` is another unspec'd `MagicMock()` instance. Using a `spec` (or
a simple explicit credentials object) would make the test more robust against
accidental attribute/method drift.
##########
providers/amazon/tests/unit/amazon/aws/datafusion/test_object_storage.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.config import ConnectionConfig, StorageType
+from airflow.providers.common.sql.datafusion.exceptions import
ObjectStoreCreationException
+
+
+class TestS3ObjectStorageProvider:
+ """Tests for S3ObjectStorageProvider in the amazon provider package."""
+
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id="aws_default",
+ conn_type="aws",
+ login="fake_id",
+ password="fake_secret",
+ extra='{"region": "us-east-1"}',
+ )
+ )
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_with_login_password(self, mock_hook_cls, mock_s3):
+ """Login/password on the connection override hook credentials."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = None
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
+
+ provider = S3ObjectStorageProvider()
+ config = ConnectionConfig(conn_id="aws_default")
+
+ store = provider.create_object_store("s3://demo-data/path",
connection_config=config)
+
+ mock_s3.assert_called_once_with(
+ access_key_id="fake_id",
+ secret_access_key="fake_secret",
+ region="us-east-1",
+ bucket_name="demo-data",
+ )
+ assert store == mock_s3.return_value
+ assert provider.get_storage_type == StorageType.S3
+ assert provider.get_scheme() == "s3://"
+
+ @patch(
+ "airflow.providers.amazon.aws.datafusion.object_storage.AmazonS3",
+ autospec=True,
+ )
+ @patch(
+
"airflow.providers.amazon.aws.datafusion.object_storage.AwsGenericHook",
+ autospec=True,
+ )
+ def test_s3_provider_falls_back_to_hook_credentials(self, mock_hook_cls,
mock_s3):
+ """When login/password are empty, hook credentials are used."""
+ from airflow.providers.amazon.aws.datafusion.object_storage import
S3ObjectStorageProvider
+
+ mock_creds = MagicMock()
+ mock_creds.access_key = "hook_key"
+ mock_creds.secret_key = "hook_secret"
+ mock_creds.token = "session_tok"
+ mock_hook_cls.return_value.get_credentials.return_value = mock_creds
Review Comment:
`mock_creds` is created as an unspec'd `MagicMock()` here as well; using a
`spec` (or a simple object with explicit attributes) makes the test stricter
and less prone to false positives.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]