This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 387126a1524 implement get hook in task sdk connections (#47401)
387126a1524 is described below
commit 387126a15246e8f90712297b633a54890a01e04f
Author: Rahul Vats <[email protected]>
AuthorDate: Thu Mar 6 22:55:34 2025 +0530
implement get hook in task sdk connections (#47401)
---
task_sdk/src/airflow/sdk/definitions/connection.py | 25 +++++++-
task_sdk/tests/definitions/test_connections.py | 72 ++++++++++++++++++++++
2 files changed, 96 insertions(+), 1 deletion(-)
diff --git a/task_sdk/src/airflow/sdk/definitions/connection.py b/task_sdk/src/airflow/sdk/definitions/connection.py
index 8d52638a604..5a447895679 100644
--- a/task_sdk/src/airflow/sdk/definitions/connection.py
+++ b/task_sdk/src/airflow/sdk/definitions/connection.py
@@ -24,6 +24,8 @@ from typing import Any
import attrs
+from airflow.exceptions import AirflowException
+
log = logging.getLogger(__name__)
@@ -56,7 +58,28 @@ class Connection:
def get_uri(self): ...
- def get_hook(self): ...
+ def get_hook(self, *, hook_params=None):
+ """Return hook based on conn_type."""
+ from airflow.providers_manager import ProvidersManager
+ from airflow.utils.module_loading import import_string
+
+ hook = ProvidersManager().hooks.get(self.conn_type, None)
+
+ if hook is None:
+ raise AirflowException(f'Unknown hook type "{self.conn_type}"')
+ try:
+ hook_class = import_string(hook.hook_class_name)
+ except ImportError:
+ log.error(
+ "Could not import %s when discovering %s %s",
+ hook.hook_class_name,
+ hook.hook_name,
+ hook.package_name,
+ )
+ raise
+ if hook_params is None:
+ hook_params = {}
+ return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
@classmethod
def get(cls, conn_id: str) -> Any:
diff --git a/task_sdk/tests/definitions/test_connections.py b/task_sdk/tests/definitions/test_connections.py
new file mode 100644
index 00000000000..e1ebde7f54f
--- /dev/null
+++ b/task_sdk/tests/definitions/test_connections.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.sdk import Connection
+
+
+class TestConnections:
+ @pytest.fixture
+ def mock_providers_manager(self):
+ """Mock the ProvidersManager to return predefined hooks."""
+ with mock.patch("airflow.providers_manager.ProvidersManager") as mock_manager:
+ yield mock_manager
+
+ @mock.patch("airflow.utils.module_loading.import_string")
+ def test_get_hook(self, mock_import_string, mock_providers_manager):
+ """Test that get_hook returns the correct hook instance."""
+
+ mock_hook_class = mock.MagicMock()
+ mock_hook_class.return_value = "mock_hook_instance"
+ mock_import_string.return_value = mock_hook_class
+
+ mock_hook = mock.MagicMock()
+ mock_hook.hook_class_name = "airflow.providers.mysql.hooks.mysql.MySqlHook"
+ mock_hook.connection_id_attribute_name = "conn_id"
+
+ mock_providers_manager.return_value.hooks = {"mysql": mock_hook}
+
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="mysql",
+ )
+
+ hook_instance = conn.get_hook(hook_params={"param1": "value1"})
+
+ mock_import_string.assert_called_once_with("airflow.providers.mysql.hooks.mysql.MySqlHook")
+
+ mock_hook_class.assert_called_once_with(conn_id="test_conn", param1="value1")
+
+ assert hook_instance == "mock_hook_instance"
+
+ def test_get_hook_invalid_type(self, mock_providers_manager):
+ """Test that get_hook raises AirflowException for unknown hook type."""
+ mock_providers_manager.return_value.hooks = {}
+
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="unknown_type",
+ )
+
+ with pytest.raises(AirflowException, match='Unknown hook type "unknown_type"'):
+ conn.get_hook()