This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c93cb32db73 Rework ProvidersManager to separate runtime and infrastructure focus (#60218)
c93cb32db73 is described below
commit c93cb32db735eb7c16c2ebce044a5795ff9c5ed7
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Jan 19 17:19:21 2026 +0530
Rework ProvidersManager to separate runtime and infrastructure focus (#60218)
Split the providers manager: ProvidersManagerTaskRuntime lives in the task-sdk and serves runtime resources, while ProvidersManager as we know it today stays to serve server components.
---
.pre-commit-config.yaml | 1 +
airflow-core/pyproject.toml | 2 +
.../src/airflow/_shared/providers_discovery | 1 +
airflow-core/src/airflow/providers_manager.py | 86 ++-
.../tests/unit/always/test_providers_manager.py | 289 +---------
devel-common/src/tests_common/pytest_plugin.py | 11 +
.../tests/unit/yandex/utils/test_user_agent.py | 11 +-
pyproject.toml | 3 +
shared/providers_discovery/pyproject.toml | 60 ++
.../airflow_shared/providers_discovery/__init__.py | 34 ++
.../providers_discovery/providers_discovery.py | 348 ++++++++++++
.../test_providers_discovery.py | 112 ++++
task-sdk/pyproject.toml | 5 +
.../src/airflow/sdk/_shared/providers_discovery | 1 +
.../src/airflow/sdk/definitions/asset/__init__.py | 5 +-
task-sdk/src/airflow/sdk/definitions/connection.py | 4 +-
.../airflow/sdk/definitions/decorators/__init__.py | 4 +-
task-sdk/src/airflow/sdk/io/fs.py | 4 +-
task-sdk/src/airflow/sdk/plugins_manager.py | 4 +-
.../src/airflow/sdk/providers_manager_runtime.py | 613 +++++++++++++++++++++
.../tests/task_sdk/definitions/test_connection.py | 2 +-
task-sdk/tests/task_sdk/docs/test_public_api.py | 1 +
.../task_sdk/test_providers_manager_runtime.py | 238 ++++++++
23 files changed, 1486 insertions(+), 353 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 87cc31be88a..fcd508b1500 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -546,6 +546,7 @@ repos:
^airflow-core/src/airflow/utils/db\.py$|
^airflow-core/src/airflow/utils/trigger_rule\.py$|
^airflow-core/tests/|
+ ^task-sdk/tests/|
^.*changelog\.(rst|txt)$|
^.*CHANGELOG\.(rst|txt)$|
^chart/values.schema\.json$|
diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml
index 73ee1cd4508..0eced80cc15 100644
--- a/airflow-core/pyproject.toml
+++ b/airflow-core/pyproject.toml
@@ -240,6 +240,7 @@ exclude = [
"../shared/timezones/src/airflow_shared/timezones" =
"src/airflow/_shared/timezones"
"../shared/listeners/src/airflow_shared/listeners" =
"src/airflow/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" =
"src/airflow/_shared/plugins_manager"
+"../shared/providers_discovery/src/airflow_shared/providers_discovery" =
"src/airflow/_shared/providers_discovery"
[tool.hatch.build.targets.custom]
path = "./hatch_build.py"
@@ -317,4 +318,5 @@ shared_distributions = [
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-timezones",
"apache-airflow-shared-plugins-manager",
+ "apache-airflow-shared-providers-discovery",
]
diff --git a/airflow-core/src/airflow/_shared/providers_discovery b/airflow-core/src/airflow/_shared/providers_discovery
new file mode 120000
index 00000000000..818cea30f33
--- /dev/null
+++ b/airflow-core/src/airflow/_shared/providers_discovery
@@ -0,0 +1 @@
+../../../../shared/providers_discovery/src/airflow_shared/providers_discovery
\ No newline at end of file
diff --git a/airflow-core/src/airflow/providers_manager.py b/airflow-core/src/airflow/providers_manager.py
index 5dc70d0fc9e..956f52857bb 100644
--- a/airflow-core/src/airflow/providers_manager.py
+++ b/airflow-core/src/airflow/providers_manager.py
@@ -33,9 +33,9 @@ from importlib.resources import files as resource_files
from time import perf_counter
from typing import TYPE_CHECKING, Any, NamedTuple, ParamSpec, TypeVar, cast
-from packaging.utils import canonicalize_name
-
-from airflow._shared.module_loading import entry_points_with_dist, import_string
+from airflow import DeprecatedImportWarning
+from airflow._shared.module_loading import import_string
+from airflow._shared.providers_discovery import discover_all_providers_from_packages
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -438,6 +438,33 @@ class ProvidersManager(LoggingMixin):
self._plugins_set: set[PluginInfo] = set()
self._init_airflow_core_hooks()
+ self._runtime_manager = None
+
+ def __getattribute__(self, name: str):
+ # Hacky but does the trick for now
+ runtime_properties = {
+ "hooks",
+ "taskflow_decorators",
+ "filesystem_module_names",
+ "asset_factories",
+ "asset_uri_handlers",
+ "asset_to_openlineage_converters",
+ }
+
+ if name in runtime_properties:
+ warnings.warn(
+                f"ProvidersManager.{name} is deprecated. Use ProvidersManagerTaskRuntime.{name} from task-sdk instead.",
+ DeprecatedImportWarning,
+ stacklevel=2,
+ )
+ if object.__getattribute__(self, "_runtime_manager") is None:
+                from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
+
+                object.__setattr__(self, "_runtime_manager", ProvidersManagerTaskRuntime())
+            return getattr(object.__getattribute__(self, "_runtime_manager"), name)
+
+ return object.__getattribute__(self, name)
+
def _init_airflow_core_hooks(self):
"""Initialize the hooks dict with default hooks from Airflow core."""
core_dummy_hooks = {
@@ -472,7 +499,7 @@ class ProvidersManager(LoggingMixin):
        # Development purpose. In production provider.yaml files are not present in the 'airflow' directory
        # So there is no risk we are going to override package provider accidentally. This can only happen
        # in case of local development
-        self._discover_all_providers_from_packages()
+        discover_all_providers_from_packages(self._provider_dict, self._provider_schema_validator)
self._verify_all_providers_all_compatible()
self._provider_dict = dict(sorted(self._provider_dict.items()))
@@ -607,57 +634,6 @@ class ProvidersManager(LoggingMixin):
self.initialize_providers_list()
self._discover_cli_command()
- def _discover_all_providers_from_packages(self) -> None:
- """
- Discover all providers by scanning packages installed.
-
-        The list of providers should be returned via the 'apache_airflow_provider'
-        entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
-        schema. Note that the schema is different at runtime than provider.yaml.schema.json.
-        The development version of provider schema is more strict and changes together with
-        the code. The runtime version is more relaxed (allows for additional properties)
-        and verifies only the subset of fields that are needed at runtime.
- """
-        for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
- if not dist.metadata:
- continue
- package_name = canonicalize_name(dist.metadata["name"])
- if package_name in self._provider_dict:
- continue
- log.debug("Loading %s from package %s", entry_point, package_name)
- version = dist.version
- provider_info = entry_point.load()()
- self._provider_schema_validator.validate(provider_info)
- provider_info_package_name = provider_info["package-name"]
- if package_name != provider_info_package_name:
- raise ValueError(
- f"The package '{package_name}' from packaging information "
-                    f"{provider_info_package_name} do not match. Please make sure they are aligned"
- )
-
-            # issue-59576: Retrieve the project.urls.documentation from dist.metadata
- project_urls = dist.metadata.get_all("Project-URL")
- documentation_url: str | None = None
-
- if project_urls:
- for entry in project_urls:
- if "," in entry:
- name, url = entry.split(",")
- if name.strip().lower() == "documentation":
- documentation_url = url
- break
-
- provider_info["documentation-url"] = documentation_url
-
- if package_name not in self._provider_dict:
-                self._provider_dict[package_name] = ProviderInfo(version, provider_info)
- else:
- log.warning(
-                    "The provider for package '%s' could not be registered from because providers for that "
- "package name have already been registered",
- package_name,
- )
-
def _discover_hooks_from_connection_types(
self,
hook_class_names_registered: set[str],
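
The delegation above works by overriding __getattribute__ so that only the deprecated runtime properties are proxied to the task-sdk manager, which is created lazily on first use. A minimal, self-contained sketch of the same pattern; the class and attribute names below are illustrative stand-ins, not the Airflow API:

    import warnings

    class _Runtime:
        """Stands in for the task-sdk manager that now owns the property."""

        hooks = {"sftp": "SFTPHook"}

    class _Server:
        """Stands in for ProvidersManager; proxies deprecated properties."""

        _DEPRECATED = {"hooks"}

        def __init__(self):
            self._runtime = None

        def __getattribute__(self, name):
            if name in _Server._DEPRECATED:
                warnings.warn(f"{name} moved to the runtime manager", DeprecationWarning, stacklevel=2)
                # create the runtime side lazily, exactly once
                if object.__getattribute__(self, "_runtime") is None:
                    object.__setattr__(self, "_runtime", _Runtime())
                return getattr(object.__getattribute__(self, "_runtime"), name)
            return object.__getattribute__(self, name)

    print(_Server().hooks)  # warns, then returns the runtime-side mapping

Using object.__getattribute__ for the internal lookups avoids recursing back into the override.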
diff --git a/airflow-core/tests/unit/always/test_providers_manager.py b/airflow-core/tests/unit/always/test_providers_manager.py
index 1e579bfdd95..7d7cc0507dd 100644
--- a/airflow-core/tests/unit/always/test_providers_manager.py
+++ b/airflow-core/tests/unit/always/test_providers_manager.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import json
import logging
import re
import sys
@@ -25,23 +24,19 @@ from collections.abc import Callable
from typing import TYPE_CHECKING
PY313 = sys.version_info >= (3, 13)
-import warnings
from unittest.mock import patch
import pytest
-from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import (
DialectInfo,
- HookClassProvider,
LazyDictWithCache,
PluginInfo,
ProviderInfo,
ProvidersManager,
)
-from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main
-from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH
+from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
if TYPE_CHECKING:
from unittest.mock import MagicMock
@@ -52,11 +47,14 @@ if TYPE_CHECKING:
def test_cleanup_providers_manager(cleanup_providers_manager):
"""Check the cleanup provider manager functionality."""
provider_manager = ProvidersManager()
- assert isinstance(provider_manager.hooks, LazyDictWithCache)
- hooks = provider_manager.hooks
+ assert isinstance(provider_manager.providers, dict)
+ providers = provider_manager.providers
+ assert len(providers) > 0
+
ProvidersManager()._cleanup()
- assert not len(hooks)
- assert ProvidersManager().hooks is hooks
+
+    # even after cleanup the singleton should return the same instance, but internal state is reset
+ assert len(ProvidersManager().providers) > 0
@skip_if_force_lowest_dependencies_marker
@@ -98,104 +96,6 @@ class TestProviderManager:
assert len(provider_list) > 65
assert self._caplog.records == []
- def test_hooks_deprecation_warnings_generated(self):
- providers_manager = ProvidersManager()
- providers_manager._provider_dict["test-package"] = ProviderInfo(
- version="0.0.1",
-            data={"hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"]},
- )
-        with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records:
- providers_manager._discover_hooks()
- assert warning_records
-
- def test_hooks_deprecation_warnings_not_generated(self):
- with warnings.catch_warnings(record=True) as warning_records:
- providers_manager = ProvidersManager()
-            providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
- version="0.0.1",
- data={
-                    "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
- "connection-types": [
- {
-                            "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
- "connection-type": "sftp",
- }
- ],
- },
- )
- providers_manager._discover_hooks()
-        assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
-
- def test_warning_logs_generated(self):
- providers_manager = ProvidersManager()
- providers_manager._hooks_lazy_dict = LazyDictWithCache()
- with self._caplog.at_level(logging.WARNING):
-            providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
- version="0.0.1",
- data={
-                    "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
- "connection-types": [
- {
-                            "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
- "connection-type": "wrong-connection-type",
- }
- ],
- },
- )
- providers_manager._discover_hooks()
- _ = providers_manager._hooks_lazy_dict["wrong-connection-type"]
- assert len(self._caplog.entries) == 1
- assert "Inconsistency!" in self._caplog[0]["event"]
- assert "sftp" not in providers_manager.hooks
-
- def test_warning_logs_not_generated(self):
- with self._caplog.at_level(logging.WARNING):
- providers_manager = ProvidersManager()
-            providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
- version="0.0.1",
- data={
-                    "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
- "connection-types": [
- {
-                            "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
- "connection-type": "sftp",
- }
- ],
- },
- )
- providers_manager._discover_hooks()
- _ = providers_manager._hooks_lazy_dict["sftp"]
- assert not self._caplog.records
- assert "sftp" in providers_manager.hooks
-
- def test_already_registered_conn_type_in_provide(self):
- with self._caplog.at_level(logging.WARNING):
- providers_manager = ProvidersManager()
-            providers_manager._provider_dict["apache-airflow-providers-dummy"] = ProviderInfo(
- version="0.0.1",
- data={
- "connection-types": [
- {
-                            "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook",
- "connection-type": "dummy",
- },
- {
-                            "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook2",
- "connection-type": "dummy",
- },
- ],
- },
- )
- providers_manager._discover_hooks()
- _ = providers_manager._hooks_lazy_dict["dummy"]
- assert len(self._caplog.records) == 1
- msg = self._caplog.messages[0]
-        assert msg.startswith("The connection type 'dummy' is already registered")
- assert (
-            "different class names: 'airflow.providers.dummy.hooks.dummy.DummyHook'"
- " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
- ) in msg
-
def test_providers_manager_register_plugins(self):
providers_manager = ProvidersManager()
providers_manager._provider_dict = LazyDictWithCache()
@@ -243,61 +143,6 @@ class TestProviderManager:
),
)
- def test_hooks(self):
- with warnings.catch_warnings(record=True) as warning_records:
- with self._caplog.at_level(logging.WARNING):
- provider_manager = ProvidersManager()
- connections_list = list(provider_manager.hooks.keys())
- assert len(connections_list) > 60
- if len(self._caplog.records) != 0:
- for record in self._caplog.records:
- print(record.message, file=sys.stderr)
- print(record.exc_info, file=sys.stderr)
-            raise AssertionError("There are warnings generated during hook imports. Please fix them")
-        assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
-
- @skip_if_not_on_main
- @pytest.mark.execution_timeout(150)
- def test_hook_values(self):
- provider_dependencies = json.loads(
-            (AIRFLOW_ROOT_PATH / "generated" / "provider_dependencies.json").read_text()
- )
- python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
- excluded_providers: list[str] = []
- for provider_name, provider_info in provider_dependencies.items():
-            if python_version in provider_info.get("excluded-python-versions", []):
-                excluded_providers.append(f"apache-airflow-providers-{provider_name.replace('.', '-')}")
- with warnings.catch_warnings(record=True) as warning_records:
- with self._caplog.at_level(logging.WARNING):
- provider_manager = ProvidersManager()
- connections_list = list(provider_manager.hooks.values())
- assert len(connections_list) > 60
- if len(self._caplog.records) != 0:
- real_warning_count = 0
- for record in self._caplog.entries:
-                # When there is error importing provider that is excluded the provider name is in the message
-                if any(excluded_provider in record["event"] for excluded_provider in excluded_providers):
- continue
- print(record["event"], file=sys.stderr)
- print(record.get("exc_info"), file=sys.stderr)
- real_warning_count += 1
- if real_warning_count:
- if PY313:
- only_ydb_and_yandexcloud_warnings = True
- for record in warning_records:
-                        if "ydb" in str(record.message) or "yandexcloud" in str(record.message):
- continue
- only_ydb_and_yandexcloud_warnings = False
- if only_ydb_and_yandexcloud_warnings:
- print(
-                            "Only warnings from ydb and yandexcloud providers are generated, "
- "which is expected in Python 3.13+",
- file=sys.stderr,
- )
- return
-            raise AssertionError("There are warnings generated during hook imports. Please fix them")
-        assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
-
def test_connection_form_widgets(self):
provider_manager = ProvidersManager()
        connections_form_widgets = list(provider_manager.connection_form_widgets.keys())
@@ -390,34 +235,6 @@ class TestProviderManager:
assert len(dialect_class_names) == 3
assert dialect_class_names == ["default", "mssql", "postgresql"]
- @patch("airflow.providers_manager.import_string")
- def test_optional_feature_no_warning(self, mock_importlib_import_string):
- with self._caplog.at_level(logging.WARNING):
-            mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
- providers_manager = ProvidersManager()
-            providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
- package_name="test_package", hook_class_name="HookClass"
- )
- providers_manager._import_hook(
-                hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
- )
- assert self._caplog.messages == []
-
- @patch("airflow.providers_manager.import_string")
- def test_optional_feature_debug(self, mock_importlib_import_string):
- with self._caplog.at_level(logging.INFO):
-            mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
- providers_manager = ProvidersManager()
-            providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
- package_name="test_package", hook_class_name="HookClass"
- )
- providers_manager._import_hook(
-                hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
- )
- assert self._caplog.messages == [
-                "Optional provider feature disabled when importing 'HookClass' from 'test_package' package"
- ]
-
class TestWithoutCheckProviderManager:
@patch("airflow.providers_manager.import_string")
@@ -456,93 +273,3 @@ class TestWithoutCheckProviderManager:
mock_correctness_check.assert_not_called()
assert providers_manager._executor_without_check_set == result
-
-
-@pytest.mark.parametrize(
- ("value", "expected_outputs"),
- [
- ("a", "a"),
- (1, 1),
- (None, None),
- (lambda: 0, 0),
- (lambda: None, None),
- (lambda: "z", "z"),
- ],
-)
-def test_lazy_cache_dict_resolving(value, expected_outputs):
- lazy_cache_dict = LazyDictWithCache()
- lazy_cache_dict["key"] = value
- assert lazy_cache_dict["key"] == expected_outputs
- # Retrieve it again to see if it is correctly returned again
- assert lazy_cache_dict["key"] == expected_outputs
-
-
-def test_lazy_cache_dict_raises_error():
- def raise_method():
- raise RuntimeError("test")
-
- lazy_cache_dict = LazyDictWithCache()
- lazy_cache_dict["key"] = raise_method
- with pytest.raises(RuntimeError, match="test"):
- _ = lazy_cache_dict["key"]
-
-
-def test_lazy_cache_dict_del_item():
- lazy_cache_dict = LazyDictWithCache()
-
- def answer():
- return 42
-
- lazy_cache_dict["spam"] = answer
- assert "spam" in lazy_cache_dict._raw_dict
- assert "spam" not in lazy_cache_dict._resolved # Not resoled yet
- assert lazy_cache_dict["spam"] == 42
- assert "spam" in lazy_cache_dict._resolved
- del lazy_cache_dict["spam"]
- assert "spam" not in lazy_cache_dict._raw_dict
- assert "spam" not in lazy_cache_dict._resolved
-
- lazy_cache_dict["foo"] = answer
- assert lazy_cache_dict["foo"] == 42
- assert "foo" in lazy_cache_dict._resolved
-    # Emulate some mess in data, e.g. value from `_raw_dict` deleted but not from `_resolved`
- del lazy_cache_dict._raw_dict["foo"]
- assert "foo" in lazy_cache_dict._resolved
- with pytest.raises(KeyError):
-        # Error expected here, but we still expect to remove also record into `resolved`
- del lazy_cache_dict["foo"]
- assert "foo" not in lazy_cache_dict._resolved
-
- lazy_cache_dict["baz"] = answer
- # Key in `_resolved` not created yet
- assert "baz" in lazy_cache_dict._raw_dict
- assert "baz" not in lazy_cache_dict._resolved
- del lazy_cache_dict._raw_dict["baz"]
- assert "baz" not in lazy_cache_dict._raw_dict
- assert "baz" not in lazy_cache_dict._resolved
-
-
-def test_lazy_cache_dict_clear():
- def answer():
- return 42
-
- lazy_cache_dict = LazyDictWithCache()
- assert len(lazy_cache_dict) == 0
- lazy_cache_dict["spam"] = answer
- lazy_cache_dict["foo"] = answer
- lazy_cache_dict["baz"] = answer
-
- assert len(lazy_cache_dict) == 3
- assert len(lazy_cache_dict._raw_dict) == 3
- assert not lazy_cache_dict._resolved
- assert lazy_cache_dict["spam"] == 42
- assert len(lazy_cache_dict._resolved) == 1
- # Emulate some mess in data, contain some data into the `_resolved`
- lazy_cache_dict._resolved.add("biz")
- assert len(lazy_cache_dict) == 3
- assert len(lazy_cache_dict._resolved) == 2
- # And finally cleanup everything
- lazy_cache_dict.clear()
- assert len(lazy_cache_dict) == 0
- assert not lazy_cache_dict._raw_dict
- assert not lazy_cache_dict._resolved
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 9dd6e782193..bc18c94f17b 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -1890,6 +1890,17 @@ def cleanup_providers_manager():
ProvidersManager().initialize_providers_configuration()
+@pytest.fixture
+def cleanup_providers_manager_runtime():
+    from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
+
+ ProvidersManagerTaskRuntime()._cleanup()
+ try:
+ yield
+ finally:
+ ProvidersManagerTaskRuntime()._cleanup()
+
+
@pytest.fixture(autouse=True)
def _disable_redact(request: pytest.FixtureRequest, mocker):
"""Disable redacted text in tests, except specific."""
diff --git a/providers/yandex/tests/unit/yandex/utils/test_user_agent.py b/providers/yandex/tests/unit/yandex/utils/test_user_agent.py
index 0fd20f85d11..b443b77e453 100644
--- a/providers/yandex/tests/unit/yandex/utils/test_user_agent.py
+++ b/providers/yandex/tests/unit/yandex/utils/test_user_agent.py
@@ -47,10 +47,9 @@ def test_provider_user_agent():
assert user_agent_prefix in user_agent
[email protected]("airflow.providers_manager.ProvidersManager.hooks")
-def test_provider_user_agent_hook_not_exists(mock_hooks):
- mock_hooks.return_value = []
+def test_provider_user_agent_hook_not_exists():
+    with mock.patch("airflow.providers_manager.ProvidersManager") as mock_pm_class:
+ mock_pm_class.return_value.hooks = {}
- user_agent = provider_user_agent()
-
- assert user_agent is None
+ user_agent = provider_user_agent()
+ assert user_agent is None
diff --git a/pyproject.toml b/pyproject.toml
index efa0a1a2a7c..9d8d15bea2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1296,6 +1296,7 @@ dev = [
"apache-airflow-shared-module-loading",
"apache-airflow-shared-observability",
"apache-airflow-shared-plugins-manager",
+ "apache-airflow-shared-providers-discovery",
"apache-airflow-shared-secrets-backend",
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-timezones",
@@ -1354,6 +1355,7 @@ apache-airflow-shared-logging = { workspace = true }
apache-airflow-shared-module-loading = { workspace = true }
apache-airflow-shared-observability = { workspace = true }
apache-airflow-shared-plugins-manager = { workspace = true }
+apache-airflow-shared-providers-discovery = { workspace = true }
apache-airflow-shared-secrets-backend = { workspace = true }
apache-airflow-shared-secrets-masker = { workspace = true }
apache-airflow-shared-timezones = { workspace = true }
@@ -1481,6 +1483,7 @@ members = [
"shared/module_loading",
"shared/observability",
"shared/plugins_manager",
+ "shared/providers_discovery",
"shared/secrets_backend",
"shared/secrets_masker",
"shared/timezones",
diff --git a/shared/providers_discovery/pyproject.toml b/shared/providers_discovery/pyproject.toml
new file mode 100644
index 00000000000..1da7f404529
--- /dev/null
+++ b/shared/providers_discovery/pyproject.toml
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[project]
+name = "apache-airflow-shared-providers-discovery"
+description = "Shared provider discovery code for Airflow distributions"
+version = "0.0"
+classifiers = [
+ "Private :: Do Not Upload",
+]
+
+dependencies = [
+ "packaging",
+ "pendulum>=3.1.0",
+ "jsonschema",
+ "structlog>=25.4.0",
+ "pygtrie>=2.5.0",
+ "methodtools>=0.4.7",
+ 'importlib_metadata>=6.5;python_version<"3.12"',
+ 'importlib_metadata>=7.0;python_version>="3.12"',
+]
+
+[dependency-groups]
+dev = [
+ "apache-airflow-devel-common",
+ "apache-airflow-shared-module-loading",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/airflow_shared"]
+
+[tool.ruff]
+extend = "../../pyproject.toml"
+src = ["src"]
+
+[tool.ruff.lint.per-file-ignores]
+# Ignore Doc rules et al for anything outside of tests
+"!src/*" = ["D", "S101", "TRY002"]
+
+[tool.ruff.lint.flake8-tidy-imports]
+# Override the workspace level default
+ban-relative-imports = "parents"
diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py
new file mode 100644
index 00000000000..a7f18110435
--- /dev/null
+++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/__init__.py
@@ -0,0 +1,34 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from .providers_discovery import (
+    KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS as KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS,
+    HookClassProvider as HookClassProvider,
+    HookInfo as HookInfo,
+    LazyDictWithCache as LazyDictWithCache,
+    PluginInfo as PluginInfo,
+    ProviderInfo as ProviderInfo,
+    _check_builtin_provider_prefix as _check_builtin_provider_prefix,
+    _create_provider_info_schema_validator as _create_provider_info_schema_validator,
+    discover_all_providers_from_packages as discover_all_providers_from_packages,
+ log_import_warning as log_import_warning,
+ log_optional_feature_disabled as log_optional_feature_disabled,
+ provider_info_cache as provider_info_cache,
+)
diff --git a/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
new file mode 100644
index 00000000000..4fc882d1b5d
--- /dev/null
+++ b/shared/providers_discovery/src/airflow_shared/providers_discovery/providers_discovery.py
@@ -0,0 +1,348 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Shared provider discovery utilities."""
+
+from __future__ import annotations
+
+import contextlib
+import json
+import pathlib
+from collections.abc import Callable, MutableMapping
+from dataclasses import dataclass
+from functools import wraps
+from importlib.resources import files as resource_files
+from time import perf_counter
+from typing import Any, NamedTuple, ParamSpec
+
+import structlog
+from packaging.utils import canonicalize_name
+
+from ..module_loading import entry_points_with_dist
+
+log = structlog.getLogger(__name__)
+
+
+PS = ParamSpec("PS")
+
+
+KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS = [("apache-airflow-providers-google", "No module named 'paramiko'")]
+
+
+@dataclass
+class ProviderInfo:
+ """
+ Provider information.
+
+ :param version: version string
+ :param data: dictionary with information about the provider
+ """
+
+ version: str
+ data: dict
+
+
+class HookClassProvider(NamedTuple):
+ """Hook class and Provider it comes from."""
+
+ hook_class_name: str
+ package_name: str
+
+
+class HookInfo(NamedTuple):
+ """Hook information."""
+
+ hook_class_name: str
+ connection_id_attribute_name: str
+ package_name: str
+ hook_name: str
+ connection_type: str
+ connection_testable: bool
+ dialects: list[str] = []
+
+
+class ConnectionFormWidgetInfo(NamedTuple):
+ """Connection Form Widget information."""
+
+ hook_class_name: str
+ package_name: str
+ field: Any
+ field_name: str
+ is_sensitive: bool
+
+
+class PluginInfo(NamedTuple):
+ """Plugin class, name and provider it comes from."""
+
+ name: str
+ plugin_class: str
+ provider_name: str
+
+
+class NotificationInfo(NamedTuple):
+ """Notification class and provider it comes from."""
+
+ notification_class_name: str
+ package_name: str
+
+
+class TriggerInfo(NamedTuple):
+ """Trigger class and provider it comes from."""
+
+ trigger_class_name: str
+ package_name: str
+ integration_name: str
+
+
+class DialectInfo(NamedTuple):
+ """Dialect class and Provider it comes from."""
+
+ name: str
+ dialect_class_name: str
+ provider_name: str
+
+
+class LazyDictWithCache(MutableMapping):
+ """
+ Lazy-loaded cached dictionary.
+
+    A dictionary which, when a value is set to a callable, calls that callable on
+    first access of the key, then returns and caches the result.
+ """
+
+ __slots__ = ["_resolved", "_raw_dict"]
+
+ def __init__(self, *args, **kw):
+ self._resolved = set()
+ self._raw_dict = dict(*args, **kw)
+
+ def __setitem__(self, key, value):
+ self._raw_dict.__setitem__(key, value)
+
+ def __getitem__(self, key):
+ value = self._raw_dict.__getitem__(key)
+ if key not in self._resolved and callable(value):
+            # exchange callable with result of calling it -- but only once! allow resolver to return a
+            # callable itself
+ value = value()
+ self._resolved.add(key)
+ self._raw_dict.__setitem__(key, value)
+ return value
+
+ def __delitem__(self, key):
+ with contextlib.suppress(KeyError):
+ self._resolved.remove(key)
+ self._raw_dict.__delitem__(key)
+
+ def __iter__(self):
+ return iter(self._raw_dict)
+
+ def __len__(self):
+ return len(self._raw_dict)
+
+ def __contains__(self, key):
+ return key in self._raw_dict
+
+ def clear(self):
+ self._resolved.clear()
+ self._raw_dict.clear()
+
+
+def _read_schema_from_resources_or_local_file(filename: str) -> dict:
+ """Read JSON schema from resources or local file."""
+ try:
+ with resource_files("airflow").joinpath(filename).open("rb") as f:
+ schema = json.load(f)
+ except (TypeError, FileNotFoundError):
+ with (pathlib.Path(__file__).parent / filename).open("rb") as f:
+ schema = json.load(f)
+ return schema
+
+
+def _create_provider_info_schema_validator():
+ """Create JSON schema validator from the provider_info.schema.json."""
+ import jsonschema
+
+    schema = _read_schema_from_resources_or_local_file("provider_info.schema.json")
+ cls = jsonschema.validators.validator_for(schema)
+ validator = cls(schema)
+ return validator
+
+
+def _create_customized_form_field_behaviours_schema_validator():
+    """Create JSON schema validator from the customized_form_field_behaviours.schema.json."""
+    import jsonschema
+
+    schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json")
+ cls = jsonschema.validators.validator_for(schema)
+ validator = cls(schema)
+ return validator
+
+
+def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
+ """Check if builtin provider class has correct prefix."""
+ if provider_package.startswith("apache-airflow"):
+ provider_path = provider_package[len("apache-") :].replace("-", ".")
+ if not class_name.startswith(provider_path):
+ log.warning(
+            "Coherence check failed when importing '%s' from '%s' package. It should start with '%s'",
+ class_name,
+ provider_package,
+ provider_path,
+ )
+ return False
+ return True
+
+
+def _ensure_prefix_for_placeholders(field_behaviors: dict[str, Any], conn_type: str):
+ """
+ Verify the correct placeholder prefix.
+
+    If the given field_behaviors dict contains a placeholders node, and there
+    are placeholders for extra fields (i.e. anything other than the built-in conn
+    attrs), and if those extra fields are unprefixed, then add the prefix.
+
+    The reason we need to do this is that all custom conn fields live in the same dictionary,
+    so we need to namespace them with a prefix internally. But for user convenience,
+    and consistency between the `get_ui_field_behaviour` method and the extra dict itself,
+    we allow users to supply the unprefixed name.
+ """
+ conn_attrs = {"host", "schema", "login", "password", "port", "extra"}
+
+ def ensure_prefix(field):
+ if field not in conn_attrs and not field.startswith("extra__"):
+ return f"extra__{conn_type}__{field}"
+ return field
+
+ if "placeholders" in field_behaviors:
+ placeholders = field_behaviors["placeholders"]
+        field_behaviors["placeholders"] = {ensure_prefix(k): v for k, v in placeholders.items()}
+
+ return field_behaviors
+
+
+def log_optional_feature_disabled(class_name, e, provider_package):
+ """Log optional feature disabled."""
+ log.debug(
+        "Optional feature disabled on exception when importing '%s' from '%s' package",
+ class_name,
+ provider_package,
+ exc_info=e,
+ )
+ log.info(
+        "Optional provider feature disabled when importing '%s' from '%s' package",
+ class_name,
+ provider_package,
+ )
+
+
+def log_import_warning(class_name, e, provider_package):
+ """Log import warning."""
+ log.warning(
+ "Exception when importing '%s' from '%s' package",
+ class_name,
+ provider_package,
+ exc_info=e,
+ )
+
+
+def provider_info_cache(cache_name: str) -> Callable[[Callable[PS, None]], Callable[PS, None]]:
+ """
+ Decorate and cache provider info.
+
+    Decorator factory that creates a decorator which caches initialization of a provider's parameters.
+
+ :param cache_name: Name of the cache
+ """
+
+    def provider_info_cache_decorator(func: Callable[PS, None]) -> Callable[PS, None]:
+ @wraps(func)
+ def wrapped_function(*args: PS.args, **kwargs: PS.kwargs) -> None:
+ instance = args[0]
+
+ if cache_name in instance._initialized_cache:
+ return
+ start_time = perf_counter()
+ log.debug("Initializing Provider Manager[%s]", cache_name)
+ func(*args, **kwargs)
+ instance._initialized_cache[cache_name] = True
+ log.debug(
+ "Initialization of Provider Manager[%s] took %.2f seconds",
+ cache_name,
+ perf_counter() - start_time,
+ )
+
+ return wrapped_function
+
+ return provider_info_cache_decorator
+
+
+def discover_all_providers_from_packages(
+ provider_dict: dict[str, ProviderInfo],
+ provider_schema_validator,
+) -> None:
+ """
+ Discover all providers by scanning packages installed.
+
+ The list of providers should be returned via the 'apache_airflow_provider'
+    entrypoint as a dictionary conforming to the 'airflow/provider_info.schema.json'
+    schema. Note that the schema is different at runtime than provider.yaml.schema.json.
+    The development version of provider schema is more strict and changes together with
+    the code. The runtime version is more relaxed (allows for additional properties)
+    and verifies only the subset of fields that are needed at runtime.
+
+ :param provider_dict: Dictionary to populate with discovered providers
+    :param provider_schema_validator: JSON schema validator for provider info
+ for entry_point, dist in entry_points_with_dist("apache_airflow_provider"):
+ if not dist.metadata:
+ continue
+ package_name = canonicalize_name(dist.metadata["name"])
+ if package_name in provider_dict:
+ continue
+ log.debug("Loading %s from package %s", entry_point, package_name)
+ version = dist.version
+ provider_info = entry_point.load()()
+ provider_schema_validator.validate(provider_info)
+ provider_info_package_name = provider_info["package-name"]
+ if package_name != provider_info_package_name:
+ raise ValueError(
+                f"The package name '{package_name}' from packaging information and "
+                f"'{provider_info_package_name}' from provider info do not match. Please make sure they are aligned"
+ )
+
+        # issue-59576: Retrieve the project.urls.documentation from dist.metadata
+ project_urls = dist.metadata.get_all("Project-URL")
+ documentation_url: str | None = None
+
+ if project_urls:
+ for entry in project_urls:
+ if "," in entry:
+ name, url = entry.split(",")
+ if name.strip().lower() == "documentation":
+ documentation_url = url
+ break
+
+ provider_info["documentation-url"] = documentation_url
+
+ if package_name not in provider_dict:
+ provider_dict[package_name] = ProviderInfo(version, provider_info)
+ else:
+ log.warning(
+                "The provider for package '%s' could not be registered because providers for that "
+ "package name have already been registered",
+ package_name,
+ )
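
As a usage sketch, the extracted function can now be driven by any caller that supplies its own dict and validator. The names below mirror this module; the driver script itself is illustrative:

    from airflow_shared.providers_discovery import (
        ProviderInfo,
        _create_provider_info_schema_validator,
        discover_all_providers_from_packages,
    )

    providers: dict[str, ProviderInfo] = {}
    validator = _create_provider_info_schema_validator()
    # populates `providers` from installed 'apache_airflow_provider' entry points
    discover_all_providers_from_packages(providers, validator)
    for name, info in sorted(providers.items()):
        print(name, info.version)

Both ProvidersManager and ProvidersManagerTaskRuntime call it exactly this way, each passing its own provider dict.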
diff --git a/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py b/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py
new file mode 100644
index 00000000000..87f97a3e5a2
--- /dev/null
+++ b/shared/providers_discovery/tests/providers_discovery/test_providers_discovery.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow_shared.providers_discovery import LazyDictWithCache
+
+
+@pytest.mark.parametrize(
+ ("value", "expected_outputs"),
+ [
+ ("a", "a"),
+ (1, 1),
+ (None, None),
+ (lambda: 0, 0),
+ (lambda: None, None),
+ (lambda: "z", "z"),
+ ],
+)
+def test_lazy_cache_dict_resolving(value, expected_outputs):
+ lazy_cache_dict = LazyDictWithCache()
+ lazy_cache_dict["key"] = value
+ assert lazy_cache_dict["key"] == expected_outputs
+ # Retrieve it again to see if it is correctly returned again
+ assert lazy_cache_dict["key"] == expected_outputs
+
+
+def test_lazy_cache_dict_raises_error():
+ def raise_method():
+ raise RuntimeError("test")
+
+ lazy_cache_dict = LazyDictWithCache()
+ lazy_cache_dict["key"] = raise_method
+ with pytest.raises(RuntimeError, match="test"):
+ _ = lazy_cache_dict["key"]
+
+
+def test_lazy_cache_dict_del_item():
+ lazy_cache_dict = LazyDictWithCache()
+
+ def answer():
+ return 42
+
+ lazy_cache_dict["spam"] = answer
+ assert "spam" in lazy_cache_dict._raw_dict
+    assert "spam" not in lazy_cache_dict._resolved  # Not resolved yet
+ assert lazy_cache_dict["spam"] == 42
+ assert "spam" in lazy_cache_dict._resolved
+ del lazy_cache_dict["spam"]
+ assert "spam" not in lazy_cache_dict._raw_dict
+ assert "spam" not in lazy_cache_dict._resolved
+
+ lazy_cache_dict["foo"] = answer
+ assert lazy_cache_dict["foo"] == 42
+ assert "foo" in lazy_cache_dict._resolved
+    # Emulate some mess in data, e.g. value from `_raw_dict` deleted but not from `_resolved`
+ del lazy_cache_dict._raw_dict["foo"]
+ assert "foo" in lazy_cache_dict._resolved
+ with pytest.raises(KeyError):
+        # Error expected here, but we still expect the record to be removed from `_resolved` too
+ del lazy_cache_dict["foo"]
+ assert "foo" not in lazy_cache_dict._resolved
+
+ lazy_cache_dict["baz"] = answer
+ # Key in `_resolved` not created yet
+ assert "baz" in lazy_cache_dict._raw_dict
+ assert "baz" not in lazy_cache_dict._resolved
+ del lazy_cache_dict._raw_dict["baz"]
+ assert "baz" not in lazy_cache_dict._raw_dict
+ assert "baz" not in lazy_cache_dict._resolved
+
+
+def test_lazy_cache_dict_clear():
+ def answer():
+ return 42
+
+ lazy_cache_dict = LazyDictWithCache()
+ assert len(lazy_cache_dict) == 0
+ lazy_cache_dict["spam"] = answer
+ lazy_cache_dict["foo"] = answer
+ lazy_cache_dict["baz"] = answer
+
+ assert len(lazy_cache_dict) == 3
+ assert len(lazy_cache_dict._raw_dict) == 3
+ assert not lazy_cache_dict._resolved
+ assert lazy_cache_dict["spam"] == 42
+ assert len(lazy_cache_dict._resolved) == 1
+    # Emulate some mess in data, e.g. an extra key present in `_resolved`
+ lazy_cache_dict._resolved.add("biz")
+ assert len(lazy_cache_dict) == 3
+ assert len(lazy_cache_dict._resolved) == 2
+ # And finally cleanup everything
+ lazy_cache_dict.clear()
+ assert len(lazy_cache_dict) == 0
+ assert not lazy_cache_dict._raw_dict
+ assert not lazy_cache_dict._resolved
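
For a quick feel of the lazy-resolution behavior these tests exercise, a minimal usage sketch:

    from airflow_shared.providers_discovery import LazyDictWithCache

    cache = LazyDictWithCache()
    calls = []

    def expensive():
        calls.append(1)          # record how many times the factory runs
        return "resolved-value"

    cache["hook"] = expensive    # store the callable, not the value
    assert cache["hook"] == "resolved-value"   # first access calls it
    assert cache["hook"] == "resolved-value"   # second access hits the cache
    assert len(calls) == 1       # the factory ran exactly once

This is exactly how hook imports are deferred: the dict holds a partial that imports the hook class only when its connection type is first looked up.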
diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml
index fc989724391..350f901ee17 100644
--- a/task-sdk/pyproject.toml
+++ b/task-sdk/pyproject.toml
@@ -84,6 +84,9 @@ dependencies = [
'importlib_metadata>=6.5;python_version<"3.12"',
"pathspec>=0.9.0",
# End of shared module-loading dependencies
+ # Start of shared providers-discovery dependencies
+ "jsonschema",
+ # End of shared providers-discovery dependencies
]
[project.optional-dependencies]
@@ -132,6 +135,7 @@ path = "src/airflow/sdk/__init__.py"
"../shared/timezones/src/airflow_shared/timezones" =
"src/airflow/sdk/_shared/timezones"
"../shared/listeners/src/airflow_shared/listeners" =
"src/airflow/sdk/_shared/listeners"
"../shared/plugins_manager/src/airflow_shared/plugins_manager" =
"src/airflow/sdk/_shared/plugins_manager"
+"../shared/providers_discovery/src/airflow_shared/providers_discovery" =
"src/airflow/sdk/_shared/providers_discovery"
[tool.hatch.build.targets.wheel]
packages = ["src/airflow"]
@@ -283,4 +287,5 @@ shared_distributions = [
"apache-airflow-shared-timezones",
"apache-airflow-shared-observability",
"apache-airflow-shared-plugins-manager",
+ "apache-airflow-shared-providers-discovery",
]
diff --git a/task-sdk/src/airflow/sdk/_shared/providers_discovery b/task-sdk/src/airflow/sdk/_shared/providers_discovery
new file mode 120000
index 00000000000..b66ada0d22b
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/_shared/providers_discovery
@@ -0,0 +1 @@
+../../../../../shared/providers_discovery/src/airflow_shared/providers_discovery
\ No newline at end of file
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
index c4fdccc8e1b..29eabafaf04 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -27,6 +27,8 @@ from typing import TYPE_CHECKING, Any, ClassVar, overload
import attrs
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
+
if TYPE_CHECKING:
from collections.abc import Collection
from urllib.parse import SplitResult
@@ -128,9 +130,8 @@ def normalize_noop(parts: SplitResult) -> SplitResult:
def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
if scheme == "file":
return normalize_noop
- from airflow.providers_manager import ProvidersManager
- return ProvidersManager().asset_uri_handlers.get(scheme)
+ return ProvidersManagerTaskRuntime().asset_uri_handlers.get(scheme)
def _get_normalized_scheme(uri: str) -> str:
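
A hedged sketch of how the normalizer lookup above gets used; the URI is made up, and _get_uri_normalizer is internal to this module:

    from urllib.parse import urlsplit

    from airflow.sdk.definitions.asset import _get_uri_normalizer

    parts = urlsplit("file:///data/report.csv")
    normalizer = _get_uri_normalizer(parts.scheme)  # "file" short-circuits to normalize_noop
    if normalizer is not None:  # None means no provider registered a handler for this scheme
        parts = normalizer(parts)
    print(parts.geturl())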
diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py
index bcf7937f034..9ea239d0f0f 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -26,6 +26,7 @@ from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
import attrs
from airflow.sdk.exceptions import AirflowException, AirflowNotFoundException, AirflowRuntimeError, ErrorType
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
log = logging.getLogger(__name__)
@@ -188,10 +189,9 @@ class Connection:
def get_hook(self, *, hook_params=None):
"""Return hook based on conn_type."""
- from airflow.providers_manager import ProvidersManager
from airflow.sdk._shared.module_loading import import_string
- hook = ProvidersManager().hooks.get(self.conn_type, None)
+ hook = ProvidersManagerTaskRuntime().hooks.get(self.conn_type, None)
if hook is None:
raise AirflowException(f'Unknown hook type "{self.conn_type}"')
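
A sketch of the lookup path this change affects, assuming a provider that registers the conn type (here sftp, via apache-airflow-providers-sftp) is installed:

    from airflow.sdk.definitions.connection import Connection

    conn = Connection(conn_id="my_sftp", conn_type="sftp", host="example.com")
    # resolves "sftp" through ProvidersManagerTaskRuntime().hooks and imports
    # the hook class lazily; raises AirflowException for unknown conn types
    hook = conn.get_hook()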
diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py
index 41a7c2d0bf2..7a4d3125e5f 100644
--- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.py
@@ -18,12 +18,12 @@ from __future__ import annotations
from collections.abc import Callable
-from airflow.providers_manager import ProvidersManager
from airflow.sdk.bases.decorator import TaskDecorator
from airflow.sdk.definitions.dag import dag
from airflow.sdk.definitions.decorators.condition import run_if, skip_if
from airflow.sdk.definitions.decorators.setup_teardown import setup_task, teardown_task
from airflow.sdk.definitions.decorators.task_group import task_group
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
# Please keep this in sync with the .pyi's __all__.
__all__ = [
@@ -47,7 +47,7 @@ class TaskDecoratorCollection:
"""Dynamically get provider-registered task decorators, e.g.
``@task.docker``."""
if name.startswith("__"):
raise AttributeError(f"{type(self).__name__} has no attribute
{name!r}")
- decorators = ProvidersManager().taskflow_decorators
+ decorators = ProvidersManagerTaskRuntime().taskflow_decorators
if name not in decorators:
raise AttributeError(f"task decorator {name!r} not found")
return decorators[name]
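
The __getattr__ hook above is what makes @task.docker-style access work without importing every provider up front. A self-contained sketch of the pattern; the registry contents are illustrative stand-ins for the real provider registry:

    class DecoratorCollection:
        # pretend provider-registered decorators; the real registry comes from
        # ProvidersManagerTaskRuntime().taskflow_decorators
        _registry = {"docker": lambda fn: fn}

        def __getattr__(self, name):
            if name.startswith("__"):
                raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")
            try:
                return self._registry[name]
            except KeyError:
                raise AttributeError(f"task decorator {name!r} not found") from None

    task = DecoratorCollection()

    @task.docker               # resolved dynamically through __getattr__
    def hello():
        return "hi"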
diff --git a/task-sdk/src/airflow/sdk/io/fs.py b/task-sdk/src/airflow/sdk/io/fs.py
index 524a6f767b2..b51be36d48a 100644
--- a/task-sdk/src/airflow/sdk/io/fs.py
+++ b/task-sdk/src/airflow/sdk/io/fs.py
@@ -24,9 +24,9 @@ from typing import TYPE_CHECKING
from fsspec.implementations.local import LocalFileSystem
-from airflow.providers_manager import ProvidersManager
from airflow.sdk._shared.module_loading import import_string
from airflow.sdk._shared.observability.metrics.stats import Stats
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
if TYPE_CHECKING:
from fsspec import AbstractFileSystem
@@ -55,7 +55,7 @@ def _register_filesystems() -> Mapping[
]:
scheme_to_fs = _BUILTIN_SCHEME_TO_FS.copy()
with Stats.timer("airflow.io.load_filesystems") as timer:
- manager = ProvidersManager()
+ manager = ProvidersManagerTaskRuntime()
for fs_module_name in manager.filesystem_module_names:
fs_module = import_string(fs_module_name)
for scheme in getattr(fs_module, "schemes", []):
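
A self-contained sketch of the registration loop above; the fake module is an illustrative stand-in for a provider filesystem module discovered via filesystem_module_names:

    from types import SimpleNamespace

    fake_fs_module = SimpleNamespace(
        schemes=["s3", "s3a"],
        get_fs=lambda conn_id=None, storage_options=None: f"S3FileSystem({conn_id})",
    )

    scheme_to_fs = {}
    for scheme in getattr(fake_fs_module, "schemes", []):
        # map each advertised scheme to the module's filesystem factory
        scheme_to_fs[scheme] = fake_fs_module.get_fs

    print(scheme_to_fs["s3"]("aws_default"))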
diff --git a/task-sdk/src/airflow/sdk/plugins_manager.py b/task-sdk/src/airflow/sdk/plugins_manager.py
index 603c0d23d0b..bdb9fd9ec0e 100644
--- a/task-sdk/src/airflow/sdk/plugins_manager.py
+++ b/task-sdk/src/airflow/sdk/plugins_manager.py
@@ -24,7 +24,6 @@ from functools import cache
from typing import TYPE_CHECKING
from airflow import settings
-from airflow.providers_manager import ProvidersManager
from airflow.sdk._shared.module_loading import import_string
from airflow.sdk._shared.observability.metrics.stats import Stats
from airflow.sdk._shared.plugins_manager import (
@@ -36,6 +35,7 @@ from airflow.sdk._shared.plugins_manager import (
is_valid_plugin,
)
from airflow.sdk.configuration import conf
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
if TYPE_CHECKING:
from airflow.listeners.listener import ListenerManager
@@ -46,7 +46,7 @@ log = logging.getLogger(__name__)
def _load_providers_plugins() -> tuple[list[AirflowPlugin], dict[str, str]]:
"""Load plugins from providers."""
log.debug("Loading plugins from providers")
- providers_manager = ProvidersManager()
+ providers_manager = ProvidersManagerTaskRuntime()
providers_manager.initialize_providers_plugins()
plugins: list[AirflowPlugin] = []
diff --git a/task-sdk/src/airflow/sdk/providers_manager_runtime.py b/task-sdk/src/airflow/sdk/providers_manager_runtime.py
new file mode 100644
index 00000000000..e7b1b65e2bf
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/providers_manager_runtime.py
@@ -0,0 +1,613 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Manages runtime provider resources for task execution."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+import traceback
+import warnings
+from collections.abc import Callable, MutableMapping
+from typing import TYPE_CHECKING, Any
+from urllib.parse import SplitResult
+
+import structlog
+
+from airflow.sdk._shared.module_loading import import_string
+from airflow.sdk._shared.providers_discovery import (
+ KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS,
+ HookClassProvider,
+ HookInfo,
+ LazyDictWithCache,
+ PluginInfo,
+ ProviderInfo,
+ _check_builtin_provider_prefix,
+ _create_provider_info_schema_validator,
+ discover_all_providers_from_packages,
+ log_import_warning,
+ log_optional_feature_disabled,
+ provider_info_cache,
+)
+from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
+from airflow.sdk.exceptions import AirflowOptionalProviderFeatureException
+
+if TYPE_CHECKING:
+ from airflow.sdk import BaseHook
+ from airflow.sdk.bases.decorator import TaskDecorator
+ from airflow.sdk.definitions.asset import Asset
+
+log = structlog.getLogger(__name__)
+
+
+def _correctness_check(provider_package: str, class_name: str, provider_info: ProviderInfo) -> Any:
+ """
+ Perform coherence check on provider classes.
+
+    For apache-airflow providers - it checks if it starts with appropriate package. For all providers
+    it tries to import the provider - checking that there are no exceptions during importing.
+    It logs appropriate warning in case it detects any problems.
+
+ :param provider_package: name of the provider package
+ :param class_name: name of the class to import
+
+    :return: the class if the class is OK, None otherwise.
+ """
+ if not _check_builtin_provider_prefix(provider_package, class_name):
+ return None
+ try:
+ imported_class = import_string(class_name)
+ except AirflowOptionalProviderFeatureException as e:
+        # When the provider class raises AirflowOptionalProviderFeatureException
+        # this is an expected case when only some classes in provider are
+        # available. We just log debug level here and print info message in logs so that
+        # the user is aware of it
+ log_optional_feature_disabled(class_name, e, provider_package)
+ return None
+ except ImportError as e:
+ if "No module named 'airflow.providers." in e.msg:
+            # handle cases where another provider is missing. This can only happen if
+            # there is an optional feature, so we log debug and print information about it
+ log_optional_feature_disabled(class_name, e, provider_package)
+ return None
+ for known_error in KNOWN_UNHANDLED_OPTIONAL_FEATURE_ERRORS:
+            # Until we convert all providers to use AirflowOptionalProviderFeatureException
+            # we assume any problem with importing another "provider" is because this is an
+            # optional feature, so we log debug and print information about it
+ if known_error[0] == provider_package and known_error[1] in e.msg:
+ log_optional_feature_disabled(class_name, e, provider_package)
+ return None
+ # But when we have no idea - we print warning to logs
+ log_import_warning(class_name, e, provider_package)
+ return None
+ except Exception as e:
+ log_import_warning(class_name, e, provider_package)
+ return None
+ return imported_class
+
+
+class ProvidersManagerTaskRuntime(LoggingMixin):
+ """
+ Manages runtime provider resources for task execution.
+
+    This is a Singleton class. The first time it is instantiated, it discovers all available
+    runtime provider resources (hooks, taskflow decorators, filesystems, asset handlers).
+ """
+
+ resource_version = "0"
+ _initialized: bool = False
+ _initialization_stack_trace = None
+ _instance: ProvidersManagerTaskRuntime | None = None
+
+ def __new__(cls):
+ if cls._instance is None:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ @staticmethod
+ def initialized() -> bool:
+ return ProvidersManagerTaskRuntime._initialized
+
+ @staticmethod
+ def initialization_stack_trace() -> str | None:
+ return ProvidersManagerTaskRuntime._initialization_stack_trace
+
+ def __init__(self):
+ """Initialize the runtime manager."""
+ # skip initialization if already initialized
+ if self.initialized():
+ return
+ super().__init__()
+ ProvidersManagerTaskRuntime._initialized = True
+ ProvidersManagerTaskRuntime._initialization_stack_trace = "".join(
+ traceback.format_stack(inspect.currentframe())
+ )
+ self._initialized_cache: dict[str, bool] = {}
+ # Keeps dict of providers keyed by module name
+ self._provider_dict: dict[str, ProviderInfo] = {}
+ self._fs_set: set[str] = set()
+        self._asset_uri_handlers: dict[str, Callable[[SplitResult], SplitResult]] = {}
+ self._asset_factories: dict[str, Callable[..., Asset]] = {}
+ self._asset_to_openlineage_converters: dict[str, Callable] = {}
+ self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()
+        # keeps mapping between connection_types and hook class, package they come from
+ self._hook_provider_dict: dict[str, HookClassProvider] = {}
+        # Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
+        self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
+ self._plugins_set: set[PluginInfo] = set()
+        self._provider_schema_validator = _create_provider_info_schema_validator()
+ self._init_airflow_core_hooks()
+
+ def _init_airflow_core_hooks(self):
+ """Initialize the hooks dict with default hooks from Airflow core."""
+ core_dummy_hooks = {
+ "generic": "Generic",
+ "email": "Email",
+ }
+ for key, display in core_dummy_hooks.items():
+ self._hooks_lazy_dict[key] = HookInfo(
+ hook_class_name=None,
+ connection_id_attribute_name=None,
+ package_name=None,
+ hook_name=display,
+ connection_type=None,
+ connection_testable=False,
+ )
+ for conn_type, class_name in (
+ ("fs", "airflow.providers.standard.hooks.filesystem.FSHook"),
+            ("package_index", "airflow.providers.standard.hooks.package_index.PackageIndexHook"),
+ ):
+ self._hooks_lazy_dict[conn_type] = functools.partial(
+ self._import_hook,
+ connection_type=None,
+ package_name="apache-airflow-providers-standard",
+ hook_class_name=class_name,
+ provider_info=None,
+ )
+
+ @provider_info_cache("list")
+ def initialize_providers_list(self):
+ """Lazy initialization of providers list."""
+ discover_all_providers_from_packages(self._provider_dict, self._provider_schema_validator)
+ self._provider_dict = dict(sorted(self._provider_dict.items()))
+
+ @provider_info_cache("hooks")
+ def initialize_providers_hooks(self):
+ """Lazy initialization of providers hooks."""
+ self._init_airflow_core_hooks()
+ self.initialize_providers_list()
+ self._discover_hooks()
+ self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
+
+ @provider_info_cache("filesystems")
+ def initialize_providers_filesystems(self):
+ """Lazy initialization of providers filesystems."""
+ self.initialize_providers_list()
+ self._discover_filesystems()
+
+ @provider_info_cache("asset_uris")
+ def initialize_providers_asset_uri_resources(self):
+ """Lazy initialization of provider asset URI handlers, factories,
converters etc."""
+ self.initialize_providers_list()
+ self._discover_asset_uri_resources()
+
+ @provider_info_cache("plugins")
+ def initialize_providers_plugins(self):
+ """Lazy initialization of providers plugins."""
+ self.initialize_providers_list()
+ self._discover_plugins()
+
+ @provider_info_cache("taskflow_decorators")
+ def initialize_providers_taskflow_decorator(self):
+ """Lazy initialization of providers taskflow decorators."""
+ self.initialize_providers_list()
+ self._discover_taskflow_decorators()
+
+ def _discover_hooks_from_connection_types(
+ self,
+ hook_class_names_registered: set[str],
+ already_registered_warning_connection_types: set[str],
+ package_name: str,
+ provider: ProviderInfo,
+ ):
+ """
+ Discover hooks from the "connection-types" property.
+
+ This is new, better method that replaces discovery from
hook-class-names as it
+ allows to lazy import individual Hook classes when they are accessed.
+ The "connection-types" keeps information about both - connection type
and class
+ name so we can discover all connection-types without importing the
classes.
+ :param hook_class_names_registered: set of registered hook class names
for this provider
+ :param already_registered_warning_connection_types: set of connections
for which warning should be
+ printed in logs as they were already registered before
+ :param package_name:
+ :param provider:
+ :return:
+ """
+ provider_uses_connection_types = False
+ connection_types = provider.data.get("connection-types")
+ if connection_types:
+ for connection_type_dict in connection_types:
+ connection_type = connection_type_dict["connection-type"]
+ hook_class_name = connection_type_dict["hook-class-name"]
+ hook_class_names_registered.add(hook_class_name)
+ already_registered = self._hook_provider_dict.get(connection_type)
+ if already_registered:
+ if already_registered.package_name != package_name:
+ already_registered_warning_connection_types.add(connection_type)
+ else:
+ log.warning(
+ "The connection type '%s' is already registered in the"
+ " package '%s' with different class names: '%s' and '%s'. ",
+ connection_type,
+ package_name,
+ already_registered.hook_class_name,
+ hook_class_name,
+ )
+ else:
+ self._hook_provider_dict[connection_type] = HookClassProvider(
+ hook_class_name=hook_class_name, package_name=package_name
+ )
+ # Defer importing hook to access time by setting import hook method as dict value
+ self._hooks_lazy_dict[connection_type] = functools.partial(
+ self._import_hook,
+ connection_type=connection_type,
+ provider_info=provider,
+ )
+ provider_uses_connection_types = True
+ return provider_uses_connection_types
+
+ def _discover_hooks_from_hook_class_names(
+ self,
+ hook_class_names_registered: set[str],
+ already_registered_warning_connection_types: set[str],
+ package_name: str,
+ provider: ProviderInfo,
+ provider_uses_connection_types: bool,
+ ):
+ """
+ Discover hooks from "hook-class-names' property.
+
+ This property is deprecated but we should support it in Airflow 2.
+ The hook-class-names array contained just Hook names without
connection type,
+ therefore we need to import all those classes immediately to know
which connection types
+ are supported. This makes it impossible to selectively only import
those hooks that are used.
+ :param already_registered_warning_connection_types: list of connection
hooks that we should warn
+ about when finished discovery
+ :param package_name: name of the provider package
+ :param provider: class that keeps information about version and
details of the provider
+ :param provider_uses_connection_types: determines whether the provider
uses "connection-types" new
+ form of passing connection types
+ :return:
+ """
+ hook_class_names = provider.data.get("hook-class-names")
+ if hook_class_names:
+ for hook_class_name in hook_class_names:
+ if hook_class_name in hook_class_names_registered:
+ # Silently ignore the hook class - it's already marked for lazy-import by
+ # connection-types discovery
+ continue
+ hook_info = self._import_hook(
+ connection_type=None,
+ provider_info=provider,
+ hook_class_name=hook_class_name,
+ package_name=package_name,
+ )
+ if not hook_info:
+ # Problem with importing the class - we ignore it. Log is written at import time
+ continue
+ already_registered = self._hook_provider_dict.get(hook_info.connection_type)
+ if already_registered:
+ if already_registered.package_name != package_name:
+ already_registered_warning_connection_types.add(hook_info.connection_type)
+ else:
+ if already_registered.hook_class_name != hook_class_name:
+ log.warning(
+ "The hook connection type '%s' is registered twice in the"
+ " package '%s' with different class names: '%s' and '%s'. "
+ " Please fix it!",
+ hook_info.connection_type,
+ package_name,
+ already_registered.hook_class_name,
+ hook_class_name,
+ )
+ else:
+ self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
+ hook_class_name=hook_class_name, package_name=package_name
+ )
+ self._hooks_lazy_dict[hook_info.connection_type] = hook_info
+
+ if not provider_uses_connection_types:
+ warnings.warn(
+ f"The provider {package_name} uses `hook-class-names` "
+ "property in provider-info and has no `connection-types`
one. "
+ "The 'hook-class-names' property has been deprecated in
favour "
+ "of 'connection-types' in Airflow 2.2. Use **both** in
case you want to "
+ "have backwards compatibility with Airflow < 2.2",
+ DeprecationWarning,
+ stacklevel=1,
+ )
+ for already_registered_connection_type in already_registered_warning_connection_types:
+ log.warning(
+ "The connection_type '%s' has already been registered by provider '%s'.",
+ already_registered_connection_type,
+ self._hook_provider_dict[already_registered_connection_type].package_name,
+ )
+
+ def _discover_hooks(self) -> None:
+ """Retrieve all connections defined in the providers via Hooks."""
+ for package_name, provider in self._provider_dict.items():
+ duplicated_connection_types: set[str] = set()
+ hook_class_names_registered: set[str] = set()
+ provider_uses_connection_types = self._discover_hooks_from_connection_types(
+ hook_class_names_registered, duplicated_connection_types, package_name, provider
+ )
+ self._discover_hooks_from_hook_class_names(
+ hook_class_names_registered,
+ duplicated_connection_types,
+ package_name,
+ provider,
+ provider_uses_connection_types,
+ )
+ self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
+
+ @staticmethod
+ def _get_attr(obj: Any, attr_name: str):
+ """Retrieve attributes of an object, or warn if not found."""
+ if not hasattr(obj, attr_name):
+ log.warning("The object '%s' is missing %s attribute and cannot be
registered", obj, attr_name)
+ return None
+ return getattr(obj, attr_name)
+
+ def _import_hook(
+ self,
+ connection_type: str | None,
+ provider_info: ProviderInfo,
+ hook_class_name: str | None = None,
+ package_name: str | None = None,
+ ) -> HookInfo | None:
+ """
+ Import hook and retrieve hook information.
+
+ Either connection_type (for lazy loading) or hook_class_name must be
set - but not both).
+ Only needs package_name if hook_class_name is passed (for lazy
loading, package_name
+ is retrieved from _connection_type_class_provider_dict together with
hook_class_name).
+
+ :param connection_type: type of the connection
+ :param hook_class_name: name of the hook class
+ :param package_name: provider package - only needed in case
connection_type is missing
+ : return
+ """
+ if connection_type is None and hook_class_name is None:
+ raise ValueError("Either connection_type or hook_class_name must
be set")
+ if connection_type is not None and hook_class_name is not None:
+ raise ValueError(
+ f"Both connection_type ({connection_type} and "
+ f"hook_class_name {hook_class_name} are set. Only one should
be set!"
+ )
+ if connection_type is not None:
+ class_provider = self._hook_provider_dict[connection_type]
+ package_name = class_provider.package_name
+ hook_class_name = class_provider.hook_class_name
+ else:
+ if not hook_class_name:
+ raise ValueError("Either connection_type or hook_class_name
must be set")
+ if not package_name:
+ raise ValueError(
+ f"Provider package name is not set when hook_class_name
({hook_class_name}) is used"
+ )
+ hook_class: type[BaseHook] | None = _correctness_check(package_name,
hook_class_name, provider_info)
+ if hook_class is None:
+ return None
+
+ hook_connection_type = self._get_attr(hook_class, "conn_type")
+ if connection_type:
+ if hook_connection_type != connection_type:
+ log.warning(
+ "Inconsistency! The hook class '%s' declares connection
type '%s'"
+ " but it is added by provider '%s' as connection_type '%s'
in provider info. "
+ "This should be fixed!",
+ hook_class,
+ hook_connection_type,
+ package_name,
+ connection_type,
+ )
+ connection_type = hook_connection_type
+ connection_id_attribute_name: str = self._get_attr(hook_class, "conn_name_attr")
+ hook_name: str = self._get_attr(hook_class, "hook_name")
+
+ if not connection_type or not connection_id_attribute_name or not hook_name:
+ log.warning(
+ "The hook misses one of the key attributes: "
+ "conn_type: %s, conn_id_attribute_name: %s, hook_name: %s",
+ connection_type,
+ connection_id_attribute_name,
+ hook_name,
+ )
+ return None
+
+ return HookInfo(
+ hook_class_name=hook_class_name,
+ connection_id_attribute_name=connection_id_attribute_name,
+ package_name=package_name,
+ hook_name=hook_name,
+ connection_type=connection_type,
+ connection_testable=hasattr(hook_class, "test_connection"),
+ )
+
+ def _discover_filesystems(self) -> None:
+ """Retrieve all filesystems defined in the providers."""
+ for provider_package, provider in self._provider_dict.items():
+ for fs_module_name in provider.data.get("filesystems", []):
+ if _correctness_check(provider_package, f"{fs_module_name}.get_fs", provider):
+ self._fs_set.add(fs_module_name)
+ self._fs_set = set(sorted(self._fs_set))
+
+ def _discover_asset_uri_resources(self) -> None:
+ """Discovers and registers asset URI handlers, factories, and
converters for all providers."""
+ from airflow.sdk.definitions.asset import normalize_noop
+
+ def _safe_register_resource(
+ provider_package_name: str,
+ schemes_list: list[str],
+ resource_path: str | None,
+ resource_registry: dict,
+ default_resource: Any = None,
+ ):
+ """
+ Register a specific resource (handler, factory, or converter) for the given schemes.
+
+ If the resolved resource (either from the path or the default) is valid, it updates
+ the resource registry with the appropriate resource for each scheme.
+ """
+ resource = (
+ _correctness_check(provider_package_name, resource_path, provider)
+ if resource_path is not None
+ else default_resource
+ )
+ if resource:
+ resource_registry.update((scheme, resource) for scheme in schemes_list)
+
+ for provider_name, provider in self._provider_dict.items():
+ for uri_info in provider.data.get("asset-uris", []):
+ if "schemes" not in uri_info or "handler" not in uri_info:
+ continue  # Both schemes and handler must be explicitly set; handler can be set to null
+ common_args = {"schemes_list": uri_info["schemes"], "provider_package_name": provider_name}
+ _safe_register_resource(
+ resource_path=uri_info["handler"],
+ resource_registry=self._asset_uri_handlers,
+ default_resource=normalize_noop,
+ **common_args,
+ )
+ _safe_register_resource(
+ resource_path=uri_info.get("factory"),
+ resource_registry=self._asset_factories,
+ **common_args,
+ )
+ _safe_register_resource(
+ resource_path=uri_info.get("to_openlineage_converter"),
+ resource_registry=self._asset_to_openlineage_converters,
+ **common_args,
+ )
+
+ def _discover_plugins(self) -> None:
+ """Retrieve all plugins defined in the providers."""
+ for provider_package, provider in self._provider_dict.items():
+ for plugin_dict in provider.data.get("plugins", ()):
+ if not _correctness_check(provider_package, plugin_dict["plugin-class"], provider):
+ log.warning("Plugin not loaded due to above correctness check problem.")
+ continue
+ self._plugins_set.add(
+ PluginInfo(
+ name=plugin_dict["name"],
+ plugin_class=plugin_dict["plugin-class"],
+ provider_name=provider_package,
+ )
+ )
+
+ def _discover_taskflow_decorators(self) -> None:
+ for name, info in self._provider_dict.items():
+ for taskflow_decorator in info.data.get("task-decorators", []):
+ self._add_taskflow_decorator(
+ taskflow_decorator["name"],
taskflow_decorator["class-name"], name
+ )
+
+ def _add_taskflow_decorator(self, name, decorator_class_name: str, provider_package: str) -> None:
+ if not _check_builtin_provider_prefix(provider_package, decorator_class_name):
+ return
+
+ if name in self._taskflow_decorators:
+ try:
+ existing = self._taskflow_decorators[name]
+ other_name = f"{existing.__module__}.{existing.__name__}"
+ except Exception:
+ # If there is a problem importing, get the value from the functools.partial
+ other_name = self._taskflow_decorators._raw_dict[name].args[0]  # type: ignore[attr-defined]
+
+ log.warning(
+ "The taskflow decorator '%s' has been already registered (by
%s).",
+ name,
+ other_name,
+ )
+ return
+
+ self._taskflow_decorators[name] = functools.partial(import_string, decorator_class_name)
+
+ @property
+ def providers(self) -> dict[str, ProviderInfo]:
+ """Returns information about available providers."""
+ self.initialize_providers_list()
+ return self._provider_dict
+
+ @property
+ def hooks(self) -> MutableMapping[str, HookInfo | None]:
+ """
+ Return dictionary of connection_type-to-hook mapping.
+
+ Note that the dict can contain None values if a discovered hook cannot be imported!
+ """
+ self.initialize_providers_hooks()
+ return self._hooks_lazy_dict
+
+ @property
+ def taskflow_decorators(self) -> dict[str, TaskDecorator]:
+ self.initialize_providers_taskflow_decorator()
+ return self._taskflow_decorators # type: ignore[return-value]
+
+ @property
+ def filesystem_module_names(self) -> list[str]:
+ self.initialize_providers_filesystems()
+ return sorted(self._fs_set)
+
+ @property
+ def asset_factories(self) -> dict[str, Callable[..., Asset]]:
+ self.initialize_providers_asset_uri_resources()
+ return self._asset_factories
+
+ @property
+ def asset_uri_handlers(self) -> dict[str, Callable[[SplitResult], SplitResult]]:
+ self.initialize_providers_asset_uri_resources()
+ return self._asset_uri_handlers
+
+ @property
+ def asset_to_openlineage_converters(
+ self,
+ ) -> dict[str, Callable]:
+ self.initialize_providers_asset_uri_resources()
+ return self._asset_to_openlineage_converters
+
+ @property
+ def plugins(self) -> list[PluginInfo]:
+ """Returns information about plugins available in providers."""
+ self.initialize_providers_plugins()
+ return sorted(self._plugins_set, key=lambda x: x.plugin_class)
+
+ def _cleanup(self):
+ self._initialized_cache.clear()
+ self._provider_dict.clear()
+ self._fs_set.clear()
+ self._taskflow_decorators.clear()
+ self._hook_provider_dict.clear()
+ self._hooks_lazy_dict.clear()
+ self._plugins_set.clear()
+ self._asset_uri_handlers.clear()
+ self._asset_factories.clear()
+ self._asset_to_openlineage_converters.clear()
+
+ self._initialized = False
+ self._initialization_stack_trace = None
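
For context, a minimal usage sketch of the new runtime manager (illustrative
only: it assumes the apache-airflow-providers-standard distribution is
installed, since the "fs" connection type is registered by
_init_airflow_core_hooks above):

    from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime

    manager = ProvidersManagerTaskRuntime()
    # Singleton: instantiating again returns the same, already-initialized object
    assert manager is ProvidersManagerTaskRuntime()

    # Accessing .hooks runs provider discovery once; each hook class is only
    # imported when its connection type is first looked up (via LazyDictWithCache)
    fs_hook_info = manager.hooks["fs"]
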
diff --git a/task-sdk/tests/task_sdk/definitions/test_connection.py b/task-sdk/tests/task_sdk/definitions/test_connection.py
index 8fca258ec79..d7811f491c9 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connection.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connection.py
@@ -36,7 +36,7 @@ class TestConnections:
@pytest.fixture
def mock_providers_manager(self):
"""Mock the ProvidersManager to return predefined hooks."""
- with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager:
+ with mock.patch("airflow.sdk.definitions.connection.ProvidersManagerTaskRuntime") as mock_manager:
yield mock_manager
@mock.patch("airflow.sdk._shared.module_loading.import_string")
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index e7f653d76a2..2186cd32b12 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -61,6 +61,7 @@ def test_airflow_sdk_no_unexpected_exports():
"observability",
"plugins_manager",
"listener",
+ "providers_manager_runtime",
}
unexpected = actual - public - ignore
assert not unexpected, f"Unexpected exports in airflow.sdk: {sorted(unexpected)}"
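
The new test module below drives discovery with hand-built provider metadata.
For orientation, a sketch of the "connection-types" shape that
_discover_hooks_from_connection_types() consumes; the package and hook names
here are illustrative, mirroring the tests:

    from airflow.sdk._shared.providers_discovery import ProviderInfo

    provider = ProviderInfo(
        version="0.0.1",
        data={
            "connection-types": [
                {
                    # the hook class is imported lazily, on first lookup of "sftp"
                    "connection-type": "sftp",
                    "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
                },
            ],
        },
    )
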
diff --git a/task-sdk/tests/task_sdk/test_providers_manager_runtime.py b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py
new file mode 100644
index 00000000000..da6600a6fda
--- /dev/null
+++ b/task-sdk/tests/task_sdk/test_providers_manager_runtime.py
@@ -0,0 +1,238 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import logging
+import sys
+import warnings
+from unittest.mock import patch
+
+import pytest
+
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.sdk._shared.providers_discovery import (
+ HookClassProvider,
+ LazyDictWithCache,
+ ProviderInfo,
+)
+from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
+
+from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker, skip_if_not_on_main
+from tests_common.test_utils.paths import AIRFLOW_ROOT_PATH
+
+PY313 = sys.version_info >= (3, 13)
+
+
+def test_cleanup_providers_manager_runtime(cleanup_providers_manager):
+ """Check the cleanup provider manager functionality."""
+ provider_manager = ProvidersManagerTaskRuntime()
+ # Check by type name since symlinks create different module paths
+ assert type(provider_manager.hooks).__name__ == "LazyDictWithCache"
+ hooks = provider_manager.hooks
+ ProvidersManagerTaskRuntime()._cleanup()
+ assert not len(hooks)
+ assert ProvidersManagerTaskRuntime().hooks is hooks
+
+
+@skip_if_force_lowest_dependencies_marker
+class TestProvidersManagerRuntime:
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, caplog, cleanup_providers_manager_runtime):
+ self._caplog = caplog
+
+ def test_hooks_deprecation_warnings_generated(self):
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._provider_dict["test-package"] = ProviderInfo(
+ version="0.0.1",
+ data={"hook-class-names":
["airflow.providers.sftp.hooks.sftp.SFTPHook"]},
+ )
+ with pytest.warns(expected_warning=DeprecationWarning, match="hook-class-names") as warning_records:
+ providers_manager._discover_hooks()
+ assert warning_records
+
+ def test_hooks_deprecation_warnings_not_generated(self):
+ with warnings.catch_warnings(record=True) as warning_records:
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "hook-class-names":
["airflow.providers.sftp.hooks.sftp.SFTPHook"],
+ "connection-types": [
+ {
+ "hook-class-name":
"airflow.providers.sftp.hooks.sftp.SFTPHook",
+ "connection-type": "sftp",
+ }
+ ],
+ },
+ )
+ providers_manager._discover_hooks()
+ assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
+
+ def test_warning_logs_generated(self):
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._hooks_lazy_dict = LazyDictWithCache()
+ with self._caplog.at_level(logging.WARNING):
+ providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
+ "connection-types": [
+ {
+ "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
+ "connection-type": "wrong-connection-type",
+ }
+ ],
+ },
+ )
+ providers_manager._discover_hooks()
+ _ = providers_manager._hooks_lazy_dict["wrong-connection-type"]
+ assert len(self._caplog.entries) == 1
+ assert "Inconsistency!" in self._caplog[0]["event"]
+ assert "sftp" not in providers_manager._hooks_lazy_dict
+
+ def test_warning_logs_not_generated(self):
+ with self._caplog.at_level(logging.WARNING):
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._provider_dict["apache-airflow-providers-sftp"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "hook-class-names": ["airflow.providers.sftp.hooks.sftp.SFTPHook"],
+ "connection-types": [
+ {
+ "hook-class-name": "airflow.providers.sftp.hooks.sftp.SFTPHook",
+ "connection-type": "sftp",
+ }
+ ],
+ },
+ )
+ providers_manager._discover_hooks()
+ _ = providers_manager._hooks_lazy_dict["sftp"]
+ assert not self._caplog.records
+ assert "sftp" in providers_manager.hooks
+
+ def test_already_registered_conn_type_in_provide(self):
+ with self._caplog.at_level(logging.WARNING):
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._provider_dict["apache-airflow-providers-dummy"] = ProviderInfo(
+ version="0.0.1",
+ data={
+ "connection-types": [
+ {
+ "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook",
+ "connection-type": "dummy",
+ },
+ {
+ "hook-class-name": "airflow.providers.dummy.hooks.dummy.DummyHook2",
+ "connection-type": "dummy",
+ },
+ ],
+ },
+ )
+ providers_manager._discover_hooks()
+ _ = providers_manager._hooks_lazy_dict["dummy"]
+ assert len(self._caplog.records) == 1
+ msg = self._caplog.messages[0]
+ assert msg.startswith("The connection type 'dummy' is already registered")
+ assert (
+ "different class names: 'airflow.providers.dummy.hooks.dummy.DummyHook'"
+ " and 'airflow.providers.dummy.hooks.dummy.DummyHook2'."
+ ) in msg
+
+ def test_hooks(self):
+ with warnings.catch_warnings(record=True) as warning_records:
+ with self._caplog.at_level(logging.WARNING):
+ provider_manager = ProvidersManagerTaskRuntime()
+ connections_list = list(provider_manager.hooks.keys())
+ assert len(connections_list) > 60
+ if len(self._caplog.records) != 0:
+ for record in self._caplog.records:
+ print(record.message, file=sys.stderr)
+ print(record.exc_info, file=sys.stderr)
+ raise AssertionError("There are warnings generated during hook imports. Please fix them")
+ assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
+
+ @skip_if_not_on_main
+ @pytest.mark.execution_timeout(150)
+ def test_hook_values(self):
+ provider_dependencies = json.loads(
+ (AIRFLOW_ROOT_PATH / "generated" / "provider_dependencies.json").read_text()
+ )
+ python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+ excluded_providers: list[str] = []
+ for provider_name, provider_info in provider_dependencies.items():
+ if python_version in provider_info.get("excluded-python-versions", []):
+ excluded_providers.append(f"apache-airflow-providers-{provider_name.replace('.', '-')}")
+ with warnings.catch_warnings(record=True) as warning_records:
+ with self._caplog.at_level(logging.WARNING):
+ provider_manager = ProvidersManagerTaskRuntime()
+ connections_list = list(provider_manager.hooks.values())
+ assert len(connections_list) > 60
+ if len(self._caplog.records) != 0:
+ real_warning_count = 0
+ for record in self._caplog.entries:
+ # When there is an error importing an excluded provider, the provider name is in the message
+ if any(excluded_provider in record["event"] for excluded_provider in excluded_providers):
+ continue
+ print(record["event"], file=sys.stderr)
+ print(record.get("exc_info"), file=sys.stderr)
+ real_warning_count += 1
+ if real_warning_count:
+ if PY313:
+ only_ydb_and_yandexcloud_warnings = True
+ for record in warning_records:
+ if "ydb" in str(record.message) or "yandexcloud" in
str(record.message):
+ continue
+ only_ydb_and_yandexcloud_warnings = False
+ if only_ydb_and_yandexcloud_warnings:
+ print(
+ "Only warnings from ydb and yandexcloud providers
are generated, "
+ "which is expected in Python 3.13+",
+ file=sys.stderr,
+ )
+ return
+ raise AssertionError("There are warnings generated during hook imports. Please fix them")
+ assert [w.message for w in warning_records if "hook-class-names" in str(w.message)] == []
+
+ @patch("airflow.sdk.providers_manager_runtime.import_string")
+ def test_optional_feature_no_warning(self, mock_importlib_import_string):
+ with self._caplog.at_level(logging.WARNING):
+ mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
+ package_name="test_package", hook_class_name="HookClass"
+ )
+ providers_manager._import_hook(
+ hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
+ )
+ assert self._caplog.messages == []
+
+ @patch("airflow.sdk.providers_manager_runtime.import_string")
+ def test_optional_feature_debug(self, mock_importlib_import_string):
+ with self._caplog.at_level(logging.INFO):
+ mock_importlib_import_string.side_effect = AirflowOptionalProviderFeatureException()
+ providers_manager = ProvidersManagerTaskRuntime()
+ providers_manager._hook_provider_dict["test_connection"] = HookClassProvider(
+ package_name="test_package", hook_class_name="HookClass"
+ )
+ providers_manager._import_hook(
+ hook_class_name=None, provider_info=None, package_name=None, connection_type="test_connection"
+ )
+ assert self._caplog.messages == [
+ "Optional provider feature disabled when importing 'HookClass'
from 'test_package' package"
+ ]
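
A closing note on the lazy-import pattern the manager relies on:
_hooks_lazy_dict stores either a ready HookInfo or a factory callable, and
callables are resolved (and cached) on first access. A self-contained sketch
of that idea; LazyDict here is a hypothetical stand-in for LazyDictWithCache,
not the shared implementation:

    import functools
    from importlib import import_module

    class LazyDict(dict):
        """Resolve factory callables on first lookup and cache the result."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._resolved: set = set()

        def __getitem__(self, key):
            value = super().__getitem__(key)
            if key not in self._resolved and callable(value):
                value = value()  # the actual import happens here, not at registration
                super().__setitem__(key, value)
                self._resolved.add(key)
            return value

    def import_class(path: str):
        """Import a dotted "pkg.module.Class" path and return the class object."""
        module_name, _, class_name = path.rpartition(".")
        return getattr(import_module(module_name), class_name)

    hooks = LazyDict()
    hooks["json-decoder"] = functools.partial(import_class, "json.JSONDecoder")
    decoder_cls = hooks["json-decoder"]  # resolved and cached only now
    assert decoder_cls is hooks["json-decoder"]  # second lookup hits the cache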