This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3461a23e77 Rework provider manager to treat Airflow core hooks like other provider hooks (#33051)
3461a23e77 is described below

commit 3461a23e773deb101f536c4df98bd9e1065137c1
Author: Jens Scheffler <[email protected]>
AuthorDate: Sun Aug 6 16:42:30 2023 +0200

    Rework provider manager to treat Airflow core hooks like other provider hooks (#33051)
---
 airflow/cli/commands/connection_command.py |  2 +-
 airflow/hooks/filesystem.py                | 42 ++++++++++++++++++++++++++++--
 airflow/providers_manager.py               | 32 +++++++++++++++++++++++
 airflow/www/forms.py                       |  4 ---
 tests/always/test_connection.py            |  6 ++---
 5 files changed, 76 insertions(+), 10 deletions(-)

diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 02251a70ec..3a6909bc15 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -150,7 +150,7 @@ def _valid_uri(uri: str) -> bool:
 @cache
 def _get_connection_types() -> list[str]:
     """Returns connection types available."""
-    _connection_types = ["fs", "mesos_framework-id", "email", "generic"]
+    _connection_types = []
     providers_manager = ProvidersManager()
     for connection_type, provider_info in providers_manager.hooks.items():
         if provider_info:
diff --git a/airflow/hooks/filesystem.py b/airflow/hooks/filesystem.py
index 39517e8cdc..d436e0050c 100644
--- a/airflow/hooks/filesystem.py
+++ b/airflow/hooks/filesystem.py
@@ -17,6 +17,9 @@
 # under the License.
 from __future__ import annotations
 
+from pathlib import Path
+from typing import Any
+
 from airflow.hooks.base import BaseHook
 
 
@@ -33,9 +36,32 @@ class FSHook(BaseHook):
     Extra: {"path": "/tmp"}
     """
 
-    def __init__(self, conn_id: str = "fs_default"):
+    conn_name_attr = "fs_conn_id"
+    default_conn_name = "fs_default"
+    conn_type = "fs"
+    hook_name = "File (path)"
+
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Returns connection widgets to add to connection form."""
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import StringField
+
+        return {"path": StringField(lazy_gettext("Path"), widget=BS3TextFieldWidget())}
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Returns custom field behaviour."""
+        return {
+            "hidden_fields": ["host", "schema", "port", "login", "password", "extra"],
+            "relabeling": {},
+            "placeholders": {},
+        }
+
+    def __init__(self, fs_conn_id: str = default_conn_name):
         super().__init__()
-        conn = self.get_connection(conn_id)
+        conn = self.get_connection(fs_conn_id)
         self.basepath = conn.extra_dejson.get("path", "")
         self.conn = conn
 
@@ -49,3 +75,15 @@ class FSHook(BaseHook):
         :return: the path.
         """
         return self.basepath
+
+    def test_connection(self):
+        """Test File connection."""
+        try:
+            p = self.get_path()
+            if not p:
+                return False, "File Path is undefined."
+            if not Path(p).exists():
+                return False, f"Path {p} does not exist."
+            return True, f"Path {p} is existing."
+        except Exception as e:
+            return False, str(e)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index b7689a92da..bba64e316e 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Callable, MutableMapping, NamedTuple, Typ
 from packaging.utils import canonicalize_name
 
 from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.hooks.filesystem import FSHook
 from airflow.typing_compat import Literal
 from airflow.utils import yaml
 from airflow.utils.entry_points import entry_points_with_dist
@@ -431,6 +432,37 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
         )
         # Set of plugins contained in providers
         self._plugins_set: set[PluginInfo] = set()
+        self._init_airflow_core_hooks()
+
+    def _init_airflow_core_hooks(self):
+        """Initializes the hooks dict with default hooks from Airflow core."""
+        core_dummy_hooks = {
+            "generic": "Generic",
+            "email": "Email",
+            "mesos_framework-id": "Mesos Framework ID",
+        }
+        for key, display in core_dummy_hooks.items():
+            self._hooks_lazy_dict[key] = HookInfo(
+                hook_class_name=None,
+                connection_id_attribute_name=None,
+                package_name=None,
+                hook_name=display,
+                connection_type=None,
+                connection_testable=False,
+            )
+        for cls in [FSHook]:
+            package_name = cls.__module__
+            hook_class_name = f"{cls.__module__}.{cls.__name__}"
+            hook_info = self._import_hook(
+                connection_type=None,
+                provider_info=None,
+                hook_class_name=hook_class_name,
+                package_name=package_name,
+            )
+            self._hook_provider_dict[hook_info.connection_type] = HookClassProvider(
+                hook_class_name=hook_class_name, package_name=package_name
+            )
+            self._hooks_lazy_dict[hook_info.connection_type] = hook_info
 
     @provider_info_cache("list")
     def initialize_providers_list(self):
diff --git a/airflow/www/forms.py b/airflow/www/forms.py
index 4a0213945c..cc16229ac9 100644
--- a/airflow/www/forms.py
+++ b/airflow/www/forms.py
@@ -189,10 +189,6 @@ def create_connection_form_class() -> type[DynamicForm]:
 
     def _iter_connection_types() -> Iterator[tuple[str, str]]:
         """List available connection types."""
-        yield ("email", "Email")
-        yield ("fs", "File (path)")
-        yield ("generic", "Generic")
-        yield ("mesos_framework-id", "Mesos Framework ID")
         for connection_type, provider_info in providers_manager.hooks.items():
             if provider_info:
                 yield (connection_type, provider_info.hook_name)
diff --git a/tests/always/test_connection.py b/tests/always/test_connection.py
index b367e98972..390f05133a 100644
--- a/tests/always/test_connection.py
+++ b/tests/always/test_connection.py
@@ -753,14 +753,14 @@ class TestConnection:
     @mock.patch.dict(
         "os.environ",
         {
-            "AIRFLOW_CONN_TEST_URI_NO_HOOK": "fs://",
+            "AIRFLOW_CONN_TEST_URI_NO_HOOK": "unknown://",
         },
     )
     def test_connection_test_no_hook(self):
-        conn = Connection(conn_id="test_uri_no_hook", conn_type="fs")
+        conn = Connection(conn_id="test_uri_no_hook", conn_type="unknown")
         res = conn.test_connection()
         assert res[0] is False
-        assert res[1] == 'Unknown hook type "fs"'
+        assert res[1] == 'Unknown hook type "unknown"'
 
     @mock.patch.dict(
         "os.environ",

Reply via email to