This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1a61eb3afd feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788)
1a61eb3afd is described below
commit 1a61eb3afdaf0496ce6c308b1a36357cfd5728b7
Author: Lukas1v <[email protected]>
AuthorDate: Sat Jun 8 07:44:57 2024 +0200
feat: OpenSearchQueryOperator using an endpoint with a self-signed certificate (#39788)
* feat: added connection options
* feat: opensearch hook unit tests
* feat: fallback to RequestsHttpConnection
* fix: static checks
* fix: static checks
* fix: static checks
* feat: opensearch static module loading
---------
Co-authored-by: Lukas Verret <[email protected]>
---
airflow/providers/opensearch/hooks/opensearch.py | 21 ++++++++++++----
.../providers/opensearch/operators/opensearch.py | 12 +++++++++-
.../providers/opensearch/hooks/test_opensearch.py | 28 ++++++++++++++++++++++
3 files changed, 55 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/opensearch/hooks/opensearch.py
b/airflow/providers/opensearch/hooks/opensearch.py
index 2b4c254b4a..c5500be108 100644
--- a/airflow/providers/opensearch/hooks/opensearch.py
+++ b/airflow/providers/opensearch/hooks/opensearch.py
@@ -19,12 +19,16 @@ from __future__ import annotations
import json
from functools import cached_property
-from typing import Any
+from typing import TYPE_CHECKING, Any
from opensearchpy import OpenSearch, RequestsHttpConnection
+if TYPE_CHECKING:
+ from opensearchpy import Connection as OpenSearchConnectionClass
+
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
+from airflow.utils.strings import to_boolean
class OpenSearchHook(BaseHook):
@@ -40,13 +44,20 @@ class OpenSearchHook(BaseHook):
conn_type = "opensearch"
hook_name = "OpenSearch Hook"
- def __init__(self, open_search_conn_id: str, log_query: bool, **kwargs:
Any):
+ def __init__(
+ self,
+ open_search_conn_id: str,
+ log_query: bool,
+ open_search_conn_class: type[OpenSearchConnectionClass] | None =
RequestsHttpConnection,
+ **kwargs: Any,
+ ):
super().__init__(**kwargs)
self.conn_id = open_search_conn_id
self.log_query = log_query
- self.use_ssl = self.conn.extra_dejson.get("use_ssl", False)
- self.verify_certs = self.conn.extra_dejson.get("verify_certs", False)
+ self.use_ssl = to_boolean(str(self.conn.extra_dejson.get("use_ssl",
False)))
+ self.verify_certs =
to_boolean(str(self.conn.extra_dejson.get("verify_certs", False)))
+ self.connection_class = open_search_conn_class
self.__SERVICE = "es"
@cached_property
@@ -62,7 +73,7 @@ class OpenSearchHook(BaseHook):
http_auth=auth,
use_ssl=self.use_ssl,
verify_certs=self.verify_certs,
- connection_class=RequestsHttpConnection,
+ connection_class=self.connection_class,
)
return client
diff --git a/airflow/providers/opensearch/operators/opensearch.py
b/airflow/providers/opensearch/operators/opensearch.py
index cc12b6e8b0..6599e5bf94 100644
--- a/airflow/providers/opensearch/operators/opensearch.py
+++ b/airflow/providers/opensearch/operators/opensearch.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
+from opensearchpy import RequestsHttpConnection
from opensearchpy.exceptions import OpenSearchException
from airflow.exceptions import AirflowException
@@ -27,6 +28,8 @@ from airflow.models import BaseOperator
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook
if TYPE_CHECKING:
+ from opensearchpy import Connection as OpenSearchConnectionClass
+
from airflow.utils.context import Context
@@ -42,6 +45,7 @@ class OpenSearchQueryOperator(BaseOperator):
:param search_object: A Search object from opensearch-dsl.
:param index_name: The name of the index to search for documents.
:param opensearch_conn_id: opensearch connection to use
+ :param opensearch_conn_class: opensearch connection class to use
:param log_query: Whether to log the query used. Defaults to True and logs
query used.
"""
@@ -54,6 +58,7 @@ class OpenSearchQueryOperator(BaseOperator):
search_object: Any | None = None,
index_name: str | None = None,
opensearch_conn_id: str = "opensearch_default",
+ opensearch_conn_class: type[OpenSearchConnectionClass] | None =
RequestsHttpConnection,
log_query: bool = True,
**kwargs,
) -> None:
@@ -61,13 +66,18 @@ class OpenSearchQueryOperator(BaseOperator):
self.query = query
self.index_name = index_name
self.opensearch_conn_id = opensearch_conn_id
+ self.opensearch_conn_class = opensearch_conn_class
self.log_query = log_query
self.search_object = search_object
@cached_property
def hook(self) -> OpenSearchHook:
"""Get an instance of an OpenSearchHook."""
- return OpenSearchHook(open_search_conn_id=self.opensearch_conn_id,
log_query=self.log_query)
+ return OpenSearchHook(
+ open_search_conn_id=self.opensearch_conn_id,
+ open_search_conn_class=self.opensearch_conn_class,
+ log_query=self.log_query,
+ )
def execute(self, context: Context) -> Any:
"""Execute a search against a given index or a Search object on an
OpenSearch Cluster."""
diff --git a/tests/providers/opensearch/hooks/test_opensearch.py
b/tests/providers/opensearch/hooks/test_opensearch.py
index 92f57d276e..84360ae73f 100644
--- a/tests/providers/opensearch/hooks/test_opensearch.py
+++ b/tests/providers/opensearch/hooks/test_opensearch.py
@@ -16,15 +16,21 @@
# under the License.
from __future__ import annotations
+from unittest import mock
+
+import opensearchpy
import pytest
+from opensearchpy import Urllib3HttpConnection
from airflow.exceptions import AirflowException
+from airflow.models import Connection
from airflow.providers.opensearch.hooks.opensearch import OpenSearchHook
pytestmark = pytest.mark.db_test
MOCK_SEARCH_RETURN = {"status": "test"}
+DEFAULT_CONN = opensearchpy.connection.http_requests.RequestsHttpConnection
class TestOpenSearchHook:
@@ -46,3 +52,25 @@ class TestOpenSearchHook:
hook = OpenSearchHook(open_search_conn_id="opensearch_default",
log_query=True)
with pytest.raises(AirflowException, match="must include one of either
a query or a document id"):
hook.delete(index_name="test_index")
+
+ @mock.patch("airflow.hooks.base.BaseHook.get_connection")
+ def test_hook_param_bool(self, mock_get_connection):
+ mock_conn = Connection(
+ conn_id="opensearch_default", extra={"use_ssl": "True",
"verify_certs": "True"}
+ )
+ mock_get_connection.return_value = mock_conn
+ hook = OpenSearchHook(open_search_conn_id="opensearch_default",
log_query=True)
+
+ assert isinstance(hook.use_ssl, bool)
+ assert isinstance(hook.verify_certs, bool)
+
+ def test_load_conn_param(self, mock_hook):
+ hook_default =
OpenSearchHook(open_search_conn_id="opensearch_default", log_query=True)
+ assert hook_default.connection_class == DEFAULT_CONN
+
+ hook_Urllib3 = OpenSearchHook(
+ open_search_conn_id="opensearch_default",
+ log_query=True,
+ open_search_conn_class=Urllib3HttpConnection,
+ )
+ assert hook_Urllib3.connection_class == Urllib3HttpConnection