This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 18ef6875663 Added notion of dialects into ProvidersManager (#43726)
18ef6875663 is described below
commit 18ef68756638c6aa208f46dbd6bd0811729c079d
Author: David Blain <[email protected]>
AuthorDate: Mon Nov 11 22:40:44 2024 +0100
Added notion of dialects into ProvidersManager (#43726)
---------
Co-authored-by: David Blain <[email protected]>
---
airflow/cli/commands/provider_command.py | 13 +++++++++++++
airflow/provider.yaml.schema.json | 21 ++++++++++++++++++++
airflow/providers_manager.py | 33 ++++++++++++++++++++++++++++++++
tests/always/test_providers_manager.py | 32 +++++++++++++++++++++++++++++++
4 files changed, 99 insertions(+)
diff --git a/airflow/cli/commands/provider_command.py b/airflow/cli/commands/provider_command.py
index 59b889564cb..c408fea9558 100644
--- a/airflow/cli/commands/provider_command.py
+++ b/airflow/cli/commands/provider_command.py
@@ -88,6 +88,19 @@ def hooks_list(args):
)
+@suppress_logs_and_warning
+@providers_configuration_loaded
+def dialects_list(args):
+ AirflowConsole().print_as(
+ data=list(ProvidersManager().dialects.values()),
+ output=args.output,
+ mapper=lambda x: {
+ "dialect_name": x.name,
+ "class": x.dialect_class_name,
+ },
+ )
+
+
@suppress_logs_and_warning
@providers_configuration_loaded
def triggers_list(args):
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index 9e7edc449a6..0b770ee2b3f 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -158,6 +158,27 @@
"additionalProperties": true
}
},
+ "dialects": {
+ "type": "array",
+ "description": "Array of dialects mapped to dialect class names",
+ "items": {
+ "type": "object",
+ "properties": {
+ "dialect-type": {
+ "description": "Type of dialect defined by the provider",
+ "type": "string"
+ },
+ "dialect-class-name": {
+ "description": "Dialect class name that implements the dialect type",
+ "type": "string"
+ }
+ },
+ "required": [
+ "dialect-type",
+ "dialect-class-name"
+ ]
+ }
+ },
"hooks": {
"type": "array",
"items": {
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 3c9f225f8e9..e5c02d0113e 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -218,6 +218,14 @@ class HookClassProvider(NamedTuple):
package_name: str
+class DialectInfo(NamedTuple):
+ """Dialect class and Provider it comes from."""
+
+ name: str
+ dialect_class_name: str
+ provider_name: str
+
+
class TriggerInfo(NamedTuple):
"""Trigger class and provider it comes from."""
@@ -250,6 +258,7 @@ class HookInfo(NamedTuple):
hook_name: str
connection_type: str
connection_testable: bool
+ dialects: list[str] = []
class ConnectionFormWidgetInfo(NamedTuple):
@@ -428,6 +437,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  # type: ignore[assignment]
# keeps mapping between connection_types and hook class, package they come from
self._hook_provider_dict: dict[str, HookClassProvider] = {}
+ self._dialect_provider_dict: dict[str, DialectInfo] = {}
# Keeps dict of hooks keyed by connection type. They are lazy evaluated at access time
self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = LazyDictWithCache()
# Keeps methods that should be used to add custom widgets tuple of keyed by name of the extra field
@@ -844,6 +854,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
for package_name, provider in self._provider_dict.items():
duplicated_connection_types: set[str] = set()
hook_class_names_registered: set[str] = set()
+ self._discover_provider_dialects(package_name, provider)
provider_uses_connection_types = self._discover_hooks_from_connection_types(
hook_class_names_registered, duplicated_connection_types, package_name, provider
)
@@ -856,6 +867,20 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
)
self._hook_provider_dict = dict(sorted(self._hook_provider_dict.items()))
+ def _discover_provider_dialects(self, provider_name: str, provider: ProviderInfo):
+ dialects = provider.data.get("dialects", [])
+ if dialects:
+ self._dialect_provider_dict.update(
+ {
+ item["dialect-type"]: DialectInfo(
+ name=item["dialect-type"],
+ dialect_class_name=item["dialect-class-name"],
+ provider_name=provider_name,
+ )
+ for item in dialects
+ }
+ )
+
@provider_info_cache("import_all_hooks")
def _import_info_from_all_hooks(self):
"""Force-import all hooks and initialize the connections/fields."""
@@ -1257,6 +1282,13 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
# When we return hooks here it will only be used to retrieve hook information
return self._hooks_lazy_dict
+ @property
+ def dialects(self) -> MutableMapping[str, DialectInfo]:
+ """Return dictionary of connection_type-to-dialect mapping."""
+ self.initialize_providers_hooks()
+ # When we return dialects here it will only be used to retrieve dialect information
+ return self._dialect_provider_dict
+
@property
def plugins(self) -> list[PluginInfo]:
"""Returns information about plugins available in providers."""
@@ -1353,6 +1385,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
self._fs_set.clear()
self._taskflow_decorators.clear()
self._hook_provider_dict.clear()
+ self._dialect_provider_dict.clear()
self._hooks_lazy_dict.clear()
self._connection_form_widgets.clear()
self._field_behaviours.clear()
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index d1d22bd8c85..3558a8cbdd7 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -32,6 +32,7 @@ from wtforms import BooleanField, Field, StringField
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers_manager import (
+ DialectInfo,
HookClassProvider,
LazyDictWithCache,
PluginInfo,
@@ -197,6 +198,32 @@ class TestProviderManager:
provider_name="apache-airflow-providers-apache-hive",
)
+ def test_providers_manager_register_dialects(self):
+ providers_manager = ProvidersManager()
+ providers_manager._provider_dict = LazyDictWithCache()
+ providers_manager._provider_dict["airflow.providers.common.sql"] = ProviderInfo(
+ version="1.19.0",
+ data={
+ "dialects": [
+ {
+ "dialect-type": "default",
+ "dialect-class-name": "airflow.providers.common.sql.dialects.dialect.Dialect",
+ }
+ ]
+ },
+ package_or_source="package",
+ )
+ providers_manager._discover_hooks()
+ assert len(providers_manager._dialect_provider_dict) == 1
+ assert providers_manager._dialect_provider_dict.popitem() == (
+ "default",
+ DialectInfo(
+ name="default",
+ dialect_class_name="airflow.providers.common.sql.dialects.dialect.Dialect",
+ provider_name="airflow.providers.common.sql",
+ ),
+ )
+
def test_hooks(self):
with warnings.catch_warnings(record=True) as warning_records:
with self._caplog.at_level(logging.WARNING):
@@ -416,6 +443,11 @@ class TestProviderManager:
auth_manager_class_names = list(provider_manager.auth_managers)
assert len(auth_manager_class_names) > 0
+ def test_dialects(self):
+ provider_manager = ProvidersManager()
+ dialect_class_names = list(provider_manager.dialects)
+ assert len(dialect_class_names) == 0
+
@patch("airflow.providers_manager.import_string")
def test_optional_feature_no_warning(self, mock_importlib_import_string):
with self._caplog.at_level(logging.WARNING):