This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a057b950d7e Implemented cursor for ElasticsearchSQLHook so it can be used through SQLExecuteQueryOperator (#46439)
a057b950d7e is described below
commit a057b950d7ecb366ff123bfb0d465823bb1066e5
Author: David Blain <[email protected]>
AuthorDate: Wed Feb 5 20:48:14 2025 +0100
Implemented cursor for ElasticsearchSQLHook so it can be used through SQLExecuteQueryOperator (#46439)
---------
Co-authored-by: David Blain <[email protected]>
---
.../providers/elasticsearch/hooks/elasticsearch.py | 93 +++++++++++--
.../elasticsearch/hooks/test_elasticsearch.py | 154 ++++++++++++++++-----
2 files changed, 207 insertions(+), 40 deletions(-)
diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
index ab1bc433d94..582e4abdb9e 100644
--- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
+++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+from collections.abc import Iterable, Mapping
from functools import cached_property
from typing import TYPE_CHECKING, Any
from urllib import parse
@@ -42,6 +43,73 @@ def connect(
return ESConnection(host, port, user, password, scheme, **kwargs)
+class ElasticsearchSQLCursor:
+ """A PEP 249-like Cursor class for Elasticsearch SQL API"""
+
+ def __init__(self, es: Elasticsearch, **kwargs):
+ self.es = es
+ self.body = {
+ "fetch_size": kwargs.get("fetch_size", 1000),
+ "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False),
+ }
+ self._response: ObjectApiResponse | None = None
+
+ @property
+ def response(self) -> ObjectApiResponse:
+ return self._response or {} # type: ignore
+
+ @response.setter
+ def response(self, value):
+ self._response = value
+
+ @property
+ def cursor(self):
+ return self.response.get("cursor")
+
+ @property
+ def rows(self):
+ return self.response.get("rows", [])
+
+ @property
+ def rowcount(self) -> int:
+ return len(self.rows)
+
+ @property
+ def description(self) -> list[tuple]:
+ return [(column["name"], column["type"]) for column in self.response.get("columns", [])]
+
+ def execute(
+ self, statement: str, params: Iterable | Mapping[str, Any] | None = None
+ ) -> ObjectApiResponse:
+ self.body["query"] = statement
+ if params:
+ self.body["params"] = params
+ self.response = self.es.sql.query(body=self.body)
+ if self.cursor:
+ self.body["cursor"] = self.cursor
+ else:
+ self.body.pop("cursor", None)
+ return self.response
+
+ def fetchone(self):
+ if self.rows:
+ return self.rows[0]
+ return None
+
+ def fetchmany(self, size: int | None = None):
+ raise NotImplementedError()
+
+ def fetchall(self):
+ results = self.rows
+ while self.cursor:
+ self.execute(statement=self.body["query"])
+ results.extend(self.rows)
+ return results
+
+ def close(self):
+ self._response = None
+
+
class ESConnection:
"""wrapper class for elasticsearch.Elasticsearch."""
@@ -67,9 +135,19 @@ class ESConnection:
else:
self.es = Elasticsearch(self.url, **self.kwargs)
- def execute_sql(self, query: str) -> ObjectApiResponse:
- sql_query = {"query": query}
- return self.es.sql.query(body=sql_query)
+ def cursor(self) -> ElasticsearchSQLCursor:
+ return ElasticsearchSQLCursor(self.es, **self.kwargs)
+
+ def close(self):
+ self.es.close()
+
+ def commit(self):
+ pass
+
+ def execute_sql(
+ self, query: str, params: Iterable | Mapping[str, Any] | None = None
+ ) -> ObjectApiResponse:
+ return self.cursor().execute(query, params)
class ElasticsearchSQLHook(DbApiHook):
@@ -84,13 +162,13 @@ class ElasticsearchSQLHook(DbApiHook):
conn_name_attr = "elasticsearch_conn_id"
default_conn_name = "elasticsearch_default"
+ connector = ESConnection
conn_type = "elasticsearch"
hook_name = "Elasticsearch"
def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.schema = schema
- self._connection = connection
def get_conn(self) -> ESConnection:
"""Return an elasticsearch connection object."""
@@ -104,11 +182,10 @@ class ElasticsearchSQLHook(DbApiHook):
"scheme": conn.schema or "http",
}
- if conn.extra_dejson.get("http_compress", False):
- conn_args["http_compress"] = bool(["http_compress"])
+ conn_args.update(conn.extra_dejson)
- if conn.extra_dejson.get("timeout", False):
- conn_args["timeout"] = conn.extra_dejson["timeout"]
+ if conn_args.get("http_compress", False):
+ conn_args["http_compress"] = bool(conn_args["http_compress"])
return connect(**conn_args)
diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
index ea34f2532de..953e7dd50ef 100644
--- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
+++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
@@ -20,15 +20,42 @@ from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock
+import pytest
from elasticsearch import Elasticsearch
+from elasticsearch._sync.client import SqlClient
+from kgb import SpyAgency
from airflow.models import Connection
+from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
from airflow.providers.elasticsearch.hooks.elasticsearch import (
ElasticsearchPythonHook,
+ ElasticsearchSQLCursor,
ElasticsearchSQLHook,
ESConnection,
)
+ROWS = [
+ [1, "Stallone", "Sylvester", "78"],
+ [2, "Statham", "Jason", "57"],
+ [3, "Li", "Jet", "61"],
+ [4, "Lundgren", "Dolph", "66"],
+ [5, "Norris", "Chuck", "84"],
+]
+RESPONSE_WITHOUT_CURSOR = {
+ "columns": [
+ {"name": "index", "type": "long"},
+ {"name": "name", "type": "text"},
+ {"name": "firstname", "type": "text"},
+ {"name": "age", "type": "long"},
+ ],
+ "rows": ROWS,
+}
+RESPONSE = {**RESPONSE_WITHOUT_CURSOR, **{"cursor": "e7f8QwXUruW2mIebzudH4BwAA//8DAA=="}}
+RESPONSES = [
+ RESPONSE,
+ RESPONSE_WITHOUT_CURSOR,
+]
+
class TestElasticsearchSQLHookConn:
def setup_method(self):
@@ -48,10 +75,68 @@ class TestElasticsearchSQLHookConn:
mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None)
+class TestElasticsearchSQLCursor:
+ def setup_method(self):
+ sql = MagicMock(spec=SqlClient)
+ sql.query.side_effect = RESPONSES
+ self.es = MagicMock(sql=sql, spec=Elasticsearch)
+
+ def test_execute(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+
+ assert cursor.execute("SELECT * FROM hollywood.actors") == RESPONSE
+
+ def test_rowcount(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+ cursor.execute("SELECT * FROM hollywood.actors")
+
+ assert cursor.rowcount == len(ROWS)
+
+ def test_description(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+ cursor.execute("SELECT * FROM hollywood.actors")
+
+ assert cursor.description == [
+ ("index", "long"),
+ ("name", "text"),
+ ("firstname", "text"),
+ ("age", "long"),
+ ]
+
+ def test_fetchone(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+ cursor.execute("SELECT * FROM hollywood.actors")
+
+ assert cursor.fetchone() == ROWS[0]
+
+ def test_fetchmany(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+ cursor.execute("SELECT * FROM hollywood.actors")
+
+ with pytest.raises(NotImplementedError):
+ cursor.fetchmany()
+
+ def test_fetchall(self):
+ cursor = ElasticsearchSQLCursor(es=self.es, options={})
+ cursor.execute("SELECT * FROM hollywood.actors")
+
+ records = cursor.fetchall()
+
+ assert len(records) == 10
+ assert records == ROWS
+
+
class TestElasticsearchSQLHook:
def setup_method(self):
- self.cur = mock.MagicMock(rowcount=0)
- self.conn = mock.MagicMock()
+ sql = MagicMock(spec=SqlClient)
+ sql.query.side_effect = RESPONSES
+ es = MagicMock(sql=sql, spec=Elasticsearch)
+ self.cur = ElasticsearchSQLCursor(es=es, options={})
+ self.spy_agency = SpyAgency()
+ self.spy_agency.spy_on(self.cur.close, call_original=True)
+ self.spy_agency.spy_on(self.cur.execute, call_original=True)
+ self.spy_agency.spy_on(self.cur.fetchall, call_original=True)
+ self.conn = MagicMock(spec=ESConnection)
self.conn.cursor.return_value = self.cur
conn = self.conn
@@ -64,55 +149,60 @@ class TestElasticsearchSQLHook:
self.db_hook = UnitTestElasticsearchSQLHook()
def test_get_first_record(self):
- statement = "SQL"
- result_sets = [("row1",), ("row2",)]
- self.cur.fetchone.return_value = result_sets[0]
+ statement = "SELECT * FROM hollywood.actors"
+
+ assert self.db_hook.get_first(statement) == ROWS[0]
- assert result_sets[0] == self.db_hook.get_first(statement)
self.conn.close.assert_called_once_with()
- self.cur.close.assert_called_once_with()
- self.cur.execute.assert_called_once_with(statement)
+ self.spy_agency.assert_spy_called(self.cur.close)
+ self.spy_agency.assert_spy_called(self.cur.execute)
def test_get_records(self):
- statement = "SQL"
- result_sets = [("row1",), ("row2",)]
- self.cur.fetchall.return_value = result_sets
+ statement = "SELECT * FROM hollywood.actors"
+
+ assert self.db_hook.get_records(statement) == ROWS
- assert result_sets == self.db_hook.get_records(statement)
self.conn.close.assert_called_once_with()
- self.cur.close.assert_called_once_with()
- self.cur.execute.assert_called_once_with(statement)
+ self.spy_agency.assert_spy_called(self.cur.close)
+ self.spy_agency.assert_spy_called(self.cur.execute)
def test_get_pandas_df(self):
- statement = "SQL"
- column = "col"
- result_sets = [("row1",), ("row2",)]
- self.cur.description = [(column,)]
- self.cur.fetchall.return_value = result_sets
+ statement = "SELECT * FROM hollywood.actors"
df = self.db_hook.get_pandas_df(statement)
- assert column == df.columns[0]
+ assert list(df.columns) == ["index", "name", "firstname", "age"]
+ assert df.values.tolist() == ROWS
+
+ self.conn.close.assert_called_once_with()
+ self.spy_agency.assert_spy_called(self.cur.close)
+ self.spy_agency.assert_spy_called(self.cur.execute)
+
+ def test_run(self):
+ statement = "SELECT * FROM hollywood.actors"
- assert result_sets[0][0] == df.values.tolist()[0][0]
- assert result_sets[1][0] == df.values.tolist()[1][0]
+ assert self.db_hook.run(statement, handler=fetch_all_handler) == ROWS
- self.cur.execute.assert_called_once_with(statement)
+ self.conn.close.assert_called_once_with()
+ self.spy_agency.assert_spy_called(self.cur.close)
+ self.spy_agency.assert_spy_called(self.cur.execute)
@mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch")
def test_execute_sql_query(self, mock_es):
mock_es_sql_client = MagicMock()
- mock_es_sql_client.query.return_value = {
- "columns": [{"name": "id"}, {"name": "first_name"}],
- "rows": [[1, "John"], [2, "Jane"]],
- }
+ mock_es_sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR
mock_es.return_value.sql = mock_es_sql_client
es_connection = ESConnection(host="localhost", port=9200)
- response = es_connection.execute_sql("SELECT * FROM index1")
- mock_es_sql_client.query.assert_called_once_with(body={"query": "SELECT * FROM index1"})
-
- assert response["rows"] == [[1, "John"], [2, "Jane"]]
- assert response["columns"] == [{"name": "id"}, {"name": "first_name"}]
+ response = es_connection.execute_sql("SELECT * FROM hollywood.actors")
+ mock_es_sql_client.query.assert_called_once_with(
+ body={
+ "fetch_size": 1000,
+ "field_multi_value_leniency": False,
+ "query": "SELECT * FROM hollywood.actors",
+ }
+ )
+
+ assert response == RESPONSE_WITHOUT_CURSOR
class MockElasticsearch: