This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 18ef6875663 Added notion of dialects into ProvidersManager (#43726)
18ef6875663 is described below

commit 18ef68756638c6aa208f46dbd6bd0811729c079d
Author: David Blain <[email protected]>
AuthorDate: Mon Nov 11 22:40:44 2024 +0100

    Added notion of dialects into ProvidersManager (#43726)
    
    
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 airflow/cli/commands/provider_command.py | 13 +++++++++++++
 airflow/provider.yaml.schema.json        | 21 ++++++++++++++++++++
 airflow/providers_manager.py             | 33 ++++++++++++++++++++++++++++++++
 tests/always/test_providers_manager.py   | 32 +++++++++++++++++++++++++++++++
 4 files changed, 99 insertions(+)

diff --git a/airflow/cli/commands/provider_command.py b/airflow/cli/commands/provider_command.py
index 59b889564cb..c408fea9558 100644
--- a/airflow/cli/commands/provider_command.py
+++ b/airflow/cli/commands/provider_command.py
@@ -88,6 +88,19 @@ def hooks_list(args):
     )
 
 
+@suppress_logs_and_warning
+@providers_configuration_loaded
+def dialects_list(args):
+    AirflowConsole().print_as(
+        data=list(ProvidersManager().dialects.values()),
+        output=args.output,
+        mapper=lambda x: {
+            "dialect_name": x.name,
+            "class": x.dialect_class_name,
+        },
+    )
+
+
 @suppress_logs_and_warning
 @providers_configuration_loaded
 def triggers_list(args):
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index 9e7edc449a6..0b770ee2b3f 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -158,6 +158,27 @@
                 "additionalProperties": true
             }
         },
+        "dialects": {
+            "type": "array",
+            "description": "Array of dialects mapped to dialect class names",
+            "items": {
+                "type": "object",
+                "properties": {
+                    "dialect-type": {
+                        "description": "Type of dialect defined by the 
provider",
+                        "type": "string"
+                    },
+                    "dialect-class-name": {
+                        "description": "Dialect class name that implements the 
dialect type",
+                        "type": "string"
+                    }
+                },
+                "required": [
+                    "dialect-type",
+                    "dialect-class-name"
+                ]
+            }
+        },
         "hooks": {
             "type": "array",
             "items": {
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 3c9f225f8e9..e5c02d0113e 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -218,6 +218,14 @@ class HookClassProvider(NamedTuple):
     package_name: str
 
 
+class DialectInfo(NamedTuple):
+    """Dialect class and Provider it comes from."""
+
+    name: str
+    dialect_class_name: str
+    provider_name: str
+
+
 class TriggerInfo(NamedTuple):
     """Trigger class and provider it comes from."""
 
@@ -250,6 +258,7 @@ class HookInfo(NamedTuple):
     hook_name: str
     connection_type: str
     connection_testable: bool
+    dialects: list[str] = []
 
 
 class ConnectionFormWidgetInfo(NamedTuple):
@@ -428,6 +437,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self._taskflow_decorators: dict[str, Callable] = LazyDictWithCache()  
# type: ignore[assignment]
         # keeps mapping between connection_types and hook class, package they 
come from
         self._hook_provider_dict: dict[str, HookClassProvider] = {}
+        self._dialect_provider_dict: dict[str, DialectInfo] = {}
         # Keeps dict of hooks keyed by connection type. They are lazy 
evaluated at access time
         self._hooks_lazy_dict: LazyDictWithCache[str, HookInfo | Callable] = 
LazyDictWithCache()
         # Keeps methods that should be used to add custom widgets tuple of 
keyed by name of the extra field
@@ -844,6 +854,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         for package_name, provider in self._provider_dict.items():
             duplicated_connection_types: set[str] = set()
             hook_class_names_registered: set[str] = set()
+            self._discover_provider_dialects(package_name, provider)
             provider_uses_connection_types = 
self._discover_hooks_from_connection_types(
                 hook_class_names_registered, duplicated_connection_types, 
package_name, provider
             )
@@ -856,6 +867,20 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
             )
         self._hook_provider_dict = 
dict(sorted(self._hook_provider_dict.items()))
 
+    def _discover_provider_dialects(self, provider_name: str, provider: 
ProviderInfo):
+        dialects = provider.data.get("dialects", [])
+        if dialects:
+            self._dialect_provider_dict.update(
+                {
+                    item["dialect-type"]: DialectInfo(
+                        name=item["dialect-type"],
+                        dialect_class_name=item["dialect-class-name"],
+                        provider_name=provider_name,
+                    )
+                    for item in dialects
+                }
+            )
+
     @provider_info_cache("import_all_hooks")
     def _import_info_from_all_hooks(self):
         """Force-import all hooks and initialize the connections/fields."""
@@ -1257,6 +1282,13 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         # When we return hooks here it will only be used to retrieve hook 
information
         return self._hooks_lazy_dict
 
+    @property
+    def dialects(self) -> MutableMapping[str, DialectInfo]:
+        """Return dictionary of connection_type-to-dialect mapping."""
+        self.initialize_providers_hooks()
+        # When we return dialects here it will only be used to retrieve 
dialect information
+        return self._dialect_provider_dict
+
     @property
     def plugins(self) -> list[PluginInfo]:
         """Returns information about plugins available in providers."""
@@ -1353,6 +1385,7 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         self._fs_set.clear()
         self._taskflow_decorators.clear()
         self._hook_provider_dict.clear()
+        self._dialect_provider_dict.clear()
         self._hooks_lazy_dict.clear()
         self._connection_form_widgets.clear()
         self._field_behaviours.clear()
diff --git a/tests/always/test_providers_manager.py b/tests/always/test_providers_manager.py
index d1d22bd8c85..3558a8cbdd7 100644
--- a/tests/always/test_providers_manager.py
+++ b/tests/always/test_providers_manager.py
@@ -32,6 +32,7 @@ from wtforms import BooleanField, Field, StringField
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers_manager import (
+    DialectInfo,
     HookClassProvider,
     LazyDictWithCache,
     PluginInfo,
@@ -197,6 +198,32 @@ class TestProviderManager:
             provider_name="apache-airflow-providers-apache-hive",
         )
 
+    def test_providers_manager_register_dialects(self):
+        providers_manager = ProvidersManager()
+        providers_manager._provider_dict = LazyDictWithCache()
+        providers_manager._provider_dict["airflow.providers.common.sql"] = 
ProviderInfo(
+            version="1.19.0",
+            data={
+                "dialects": [
+                    {
+                        "dialect-type": "default",
+                        "dialect-class-name": 
"airflow.providers.common.sql.dialects.dialect.Dialect",
+                    }
+                ]
+            },
+            package_or_source="package",
+        )
+        providers_manager._discover_hooks()
+        assert len(providers_manager._dialect_provider_dict) == 1
+        assert providers_manager._dialect_provider_dict.popitem() == (
+            "default",
+            DialectInfo(
+                name="default",
+                
dialect_class_name="airflow.providers.common.sql.dialects.dialect.Dialect",
+                provider_name="airflow.providers.common.sql",
+            ),
+        )
+
     def test_hooks(self):
         with warnings.catch_warnings(record=True) as warning_records:
             with self._caplog.at_level(logging.WARNING):
@@ -416,6 +443,11 @@ class TestProviderManager:
         auth_manager_class_names = list(provider_manager.auth_managers)
         assert len(auth_manager_class_names) > 0
 
+    def test_dialects(self):
+        provider_manager = ProvidersManager()
+        dialect_class_names = list(provider_manager.dialects)
+        assert len(dialect_class_names) == 0
+
     @patch("airflow.providers_manager.import_string")
     def test_optional_feature_no_warning(self, mock_importlib_import_string):
         with self._caplog.at_level(logging.WARNING):

Reply via email to