This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 387126a1524 implement get hook in task sdk connections (#47401)
387126a1524 is described below

commit 387126a15246e8f90712297b633a54890a01e04f
Author: Rahul Vats <[email protected]>
AuthorDate: Thu Mar 6 22:55:34 2025 +0530

    implement get hook in task sdk connections (#47401)
---
 task_sdk/src/airflow/sdk/definitions/connection.py | 25 +++++++-
 task_sdk/tests/definitions/test_connections.py     | 72 ++++++++++++++++++++++
 2 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/connection.py b/task_sdk/src/airflow/sdk/definitions/connection.py
index 8d52638a604..5a447895679 100644
--- a/task_sdk/src/airflow/sdk/definitions/connection.py
+++ b/task_sdk/src/airflow/sdk/definitions/connection.py
@@ -24,6 +24,8 @@ from typing import Any
 
 import attrs
 
+from airflow.exceptions import AirflowException
+
 log = logging.getLogger(__name__)
 
 
@@ -56,7 +58,28 @@ class Connection:
 
     def get_uri(self): ...
 
-    def get_hook(self): ...
+    def get_hook(self, *, hook_params=None):
+        """Return hook based on conn_type."""
+        from airflow.providers_manager import ProvidersManager
+        from airflow.utils.module_loading import import_string
+
+        hook = ProvidersManager().hooks.get(self.conn_type, None)
+
+        if hook is None:
+            raise AirflowException(f'Unknown hook type "{self.conn_type}"')
+        try:
+            hook_class = import_string(hook.hook_class_name)
+        except ImportError:
+            log.error(
+                "Could not import %s when discovering %s %s",
+                hook.hook_class_name,
+                hook.hook_name,
+                hook.package_name,
+            )
+            raise
+        if hook_params is None:
+            hook_params = {}
+        return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
 
     @classmethod
     def get(cls, conn_id: str) -> Any:
diff --git a/task_sdk/tests/definitions/test_connections.py b/task_sdk/tests/definitions/test_connections.py
new file mode 100644
index 00000000000..e1ebde7f54f
--- /dev/null
+++ b/task_sdk/tests/definitions/test_connections.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.sdk import Connection
+
+
+class TestConnections:
+    @pytest.fixture
+    def mock_providers_manager(self):
+        """Mock the ProvidersManager to return predefined hooks."""
+        with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager:
+            yield mock_manager
+
+    @mock.patch("airflow.utils.module_loading.import_string")
+    def test_get_hook(self, mock_import_string, mock_providers_manager):
+        """Test that get_hook returns the correct hook instance."""
+
+        mock_hook_class = mock.MagicMock()
+        mock_hook_class.return_value = "mock_hook_instance"
+        mock_import_string.return_value = mock_hook_class
+
+        mock_hook = mock.MagicMock()
+        mock_hook.hook_class_name = "airflow.providers.mysql.hooks.mysql.MySqlHook"
+        mock_hook.connection_id_attribute_name = "conn_id"
+
+        mock_providers_manager.return_value.hooks = {"mysql": mock_hook}
+
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="mysql",
+        )
+
+        hook_instance = conn.get_hook(hook_params={"param1": "value1"})
+
+        mock_import_string.assert_called_once_with("airflow.providers.mysql.hooks.mysql.MySqlHook")
+
+        mock_hook_class.assert_called_once_with(conn_id="test_conn", param1="value1")
+
+        assert hook_instance == "mock_hook_instance"
+
+    def test_get_hook_invalid_type(self, mock_providers_manager):
+        """Test that get_hook raises AirflowException for unknown hook type."""
+        mock_providers_manager.return_value.hooks = {}
+
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="unknown_type",
+        )
+
+        with pytest.raises(AirflowException, match='Unknown hook type "unknown_type"'):
+            conn.get_hook()

Reply via email to