This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a057b950d7e Implemented cursor for ElasticsearchSQLHook so it can be used through SQLExecuteQueryOperator (#46439)
a057b950d7e is described below

commit a057b950d7ecb366ff123bfb0d465823bb1066e5
Author: David Blain <[email protected]>
AuthorDate: Wed Feb 5 20:48:14 2025 +0100

    Implemented cursor for ElasticsearchSQLHook so it can be used through SQLExecuteQueryOperator (#46439)
    
    
    
    ---------
    
    Co-authored-by: David Blain <[email protected]>
---
 .../providers/elasticsearch/hooks/elasticsearch.py |  93 +++++++++++--
 .../elasticsearch/hooks/test_elasticsearch.py      | 154 ++++++++++++++++-----
 2 files changed, 207 insertions(+), 40 deletions(-)

diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
index ab1bc433d94..582e4abdb9e 100644
--- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
+++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Iterable, Mapping
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 from urllib import parse
@@ -42,6 +43,73 @@ def connect(
     return ESConnection(host, port, user, password, scheme, **kwargs)
 
 
+class ElasticsearchSQLCursor:
+    """A PEP 249-like Cursor class for Elasticsearch SQL API"""
+
+    def __init__(self, es: Elasticsearch, **kwargs):
+        self.es = es
+        self.body = {
+            "fetch_size": kwargs.get("fetch_size", 1000),
+            "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False),
+        }
+        self._response: ObjectApiResponse | None = None
+
+    @property
+    def response(self) -> ObjectApiResponse:
+        return self._response or {}  # type: ignore
+
+    @response.setter
+    def response(self, value):
+        self._response = value
+
+    @property
+    def cursor(self):
+        return self.response.get("cursor")
+
+    @property
+    def rows(self):
+        return self.response.get("rows", [])
+
+    @property
+    def rowcount(self) -> int:
+        return len(self.rows)
+
+    @property
+    def description(self) -> list[tuple]:
+        return [(column["name"], column["type"]) for column in self.response.get("columns", [])]
+
+    def execute(
+        self, statement: str, params: Iterable | Mapping[str, Any] | None = None
+    ) -> ObjectApiResponse:
+        self.body["query"] = statement
+        if params:
+            self.body["params"] = params
+        self.response = self.es.sql.query(body=self.body)
+        if self.cursor:
+            self.body["cursor"] = self.cursor
+        else:
+            self.body.pop("cursor", None)
+        return self.response
+
+    def fetchone(self):
+        if self.rows:
+            return self.rows[0]
+        return None
+
+    def fetchmany(self, size: int | None = None):
+        raise NotImplementedError()
+
+    def fetchall(self):
+        results = self.rows
+        while self.cursor:
+            self.execute(statement=self.body["query"])
+            results.extend(self.rows)
+        return results
+
+    def close(self):
+        self._response = None
+
+
 class ESConnection:
     """wrapper class for elasticsearch.Elasticsearch."""
 
@@ -67,9 +135,19 @@ class ESConnection:
         else:
             self.es = Elasticsearch(self.url, **self.kwargs)
 
-    def execute_sql(self, query: str) -> ObjectApiResponse:
-        sql_query = {"query": query}
-        return self.es.sql.query(body=sql_query)
+    def cursor(self) -> ElasticsearchSQLCursor:
+        return ElasticsearchSQLCursor(self.es, **self.kwargs)
+
+    def close(self):
+        self.es.close()
+
+    def commit(self):
+        pass
+
+    def execute_sql(
+        self, query: str, params: Iterable | Mapping[str, Any] | None = None
+    ) -> ObjectApiResponse:
+        return self.cursor().execute(query, params)
 
 
 class ElasticsearchSQLHook(DbApiHook):
@@ -84,13 +162,13 @@ class ElasticsearchSQLHook(DbApiHook):
 
     conn_name_attr = "elasticsearch_conn_id"
     default_conn_name = "elasticsearch_default"
+    connector = ESConnection
     conn_type = "elasticsearch"
     hook_name = "Elasticsearch"
 
     def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.schema = schema
-        self._connection = connection
 
     def get_conn(self) -> ESConnection:
         """Return an elasticsearch connection object."""
@@ -104,11 +182,10 @@ class ElasticsearchSQLHook(DbApiHook):
             "scheme": conn.schema or "http",
         }
 
-        if conn.extra_dejson.get("http_compress", False):
-            conn_args["http_compress"] = bool(["http_compress"])
+        conn_args.update(conn.extra_dejson)
 
-        if conn.extra_dejson.get("timeout", False):
-            conn_args["timeout"] = conn.extra_dejson["timeout"]
+        if conn_args.get("http_compress", False):
+            conn_args["http_compress"] = bool(conn_args["http_compress"])
 
         return connect(**conn_args)
 
diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
index ea34f2532de..953e7dd50ef 100644
--- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
+++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py
@@ -20,15 +20,42 @@ from __future__ import annotations
 from unittest import mock
 from unittest.mock import MagicMock
 
+import pytest
 from elasticsearch import Elasticsearch
+from elasticsearch._sync.client import SqlClient
+from kgb import SpyAgency
 
 from airflow.models import Connection
+from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
 from airflow.providers.elasticsearch.hooks.elasticsearch import (
     ElasticsearchPythonHook,
+    ElasticsearchSQLCursor,
     ElasticsearchSQLHook,
     ESConnection,
 )
 
+ROWS = [
+    [1, "Stallone", "Sylvester", "78"],
+    [2, "Statham", "Jason", "57"],
+    [3, "Li", "Jet", "61"],
+    [4, "Lundgren", "Dolph", "66"],
+    [5, "Norris", "Chuck", "84"],
+]
+RESPONSE_WITHOUT_CURSOR = {
+    "columns": [
+        {"name": "index", "type": "long"},
+        {"name": "name", "type": "text"},
+        {"name": "firstname", "type": "text"},
+        {"name": "age", "type": "long"},
+    ],
+    "rows": ROWS,
+}
+RESPONSE = {**RESPONSE_WITHOUT_CURSOR, **{"cursor": "e7f8QwXUruW2mIebzudH4BwAA//8DAA=="}}
+RESPONSES = [
+    RESPONSE,
+    RESPONSE_WITHOUT_CURSOR,
+]
+
 
 class TestElasticsearchSQLHookConn:
     def setup_method(self):
@@ -48,10 +75,68 @@ class TestElasticsearchSQLHookConn:
         mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None)
 
 
+class TestElasticsearchSQLCursor:
+    def setup_method(self):
+        sql = MagicMock(spec=SqlClient)
+        sql.query.side_effect = RESPONSES
+        self.es = MagicMock(sql=sql, spec=Elasticsearch)
+
+    def test_execute(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+
+        assert cursor.execute("SELECT * FROM hollywood.actors") == RESPONSE
+
+    def test_rowcount(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+        cursor.execute("SELECT * FROM hollywood.actors")
+
+        assert cursor.rowcount == len(ROWS)
+
+    def test_description(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+        cursor.execute("SELECT * FROM hollywood.actors")
+
+        assert cursor.description == [
+            ("index", "long"),
+            ("name", "text"),
+            ("firstname", "text"),
+            ("age", "long"),
+        ]
+
+    def test_fetchone(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+        cursor.execute("SELECT * FROM hollywood.actors")
+
+        assert cursor.fetchone() == ROWS[0]
+
+    def test_fetchmany(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+        cursor.execute("SELECT * FROM hollywood.actors")
+
+        with pytest.raises(NotImplementedError):
+            cursor.fetchmany()
+
+    def test_fetchall(self):
+        cursor = ElasticsearchSQLCursor(es=self.es, options={})
+        cursor.execute("SELECT * FROM hollywood.actors")
+
+        records = cursor.fetchall()
+
+        assert len(records) == 10
+        assert records == ROWS
+
+
 class TestElasticsearchSQLHook:
     def setup_method(self):
-        self.cur = mock.MagicMock(rowcount=0)
-        self.conn = mock.MagicMock()
+        sql = MagicMock(spec=SqlClient)
+        sql.query.side_effect = RESPONSES
+        es = MagicMock(sql=sql, spec=Elasticsearch)
+        self.cur = ElasticsearchSQLCursor(es=es, options={})
+        self.spy_agency = SpyAgency()
+        self.spy_agency.spy_on(self.cur.close, call_original=True)
+        self.spy_agency.spy_on(self.cur.execute, call_original=True)
+        self.spy_agency.spy_on(self.cur.fetchall, call_original=True)
+        self.conn = MagicMock(spec=ESConnection)
         self.conn.cursor.return_value = self.cur
         conn = self.conn
 
@@ -64,55 +149,60 @@ class TestElasticsearchSQLHook:
         self.db_hook = UnitTestElasticsearchSQLHook()
 
     def test_get_first_record(self):
-        statement = "SQL"
-        result_sets = [("row1",), ("row2",)]
-        self.cur.fetchone.return_value = result_sets[0]
+        statement = "SELECT * FROM hollywood.actors"
+
+        assert self.db_hook.get_first(statement) == ROWS[0]
 
-        assert result_sets[0] == self.db_hook.get_first(statement)
         self.conn.close.assert_called_once_with()
-        self.cur.close.assert_called_once_with()
-        self.cur.execute.assert_called_once_with(statement)
+        self.spy_agency.assert_spy_called(self.cur.close)
+        self.spy_agency.assert_spy_called(self.cur.execute)
 
     def test_get_records(self):
-        statement = "SQL"
-        result_sets = [("row1",), ("row2",)]
-        self.cur.fetchall.return_value = result_sets
+        statement = "SELECT * FROM hollywood.actors"
+
+        assert self.db_hook.get_records(statement) == ROWS
 
-        assert result_sets == self.db_hook.get_records(statement)
         self.conn.close.assert_called_once_with()
-        self.cur.close.assert_called_once_with()
-        self.cur.execute.assert_called_once_with(statement)
+        self.spy_agency.assert_spy_called(self.cur.close)
+        self.spy_agency.assert_spy_called(self.cur.execute)
 
     def test_get_pandas_df(self):
-        statement = "SQL"
-        column = "col"
-        result_sets = [("row1",), ("row2",)]
-        self.cur.description = [(column,)]
-        self.cur.fetchall.return_value = result_sets
+        statement = "SELECT * FROM hollywood.actors"
         df = self.db_hook.get_pandas_df(statement)
 
-        assert column == df.columns[0]
+        assert list(df.columns) == ["index", "name", "firstname", "age"]
+        assert df.values.tolist() == ROWS
+
+        self.conn.close.assert_called_once_with()
+        self.spy_agency.assert_spy_called(self.cur.close)
+        self.spy_agency.assert_spy_called(self.cur.execute)
+
+    def test_run(self):
+        statement = "SELECT * FROM hollywood.actors"
 
-        assert result_sets[0][0] == df.values.tolist()[0][0]
-        assert result_sets[1][0] == df.values.tolist()[1][0]
+        assert self.db_hook.run(statement, handler=fetch_all_handler) == ROWS
 
-        self.cur.execute.assert_called_once_with(statement)
+        self.conn.close.assert_called_once_with()
+        self.spy_agency.assert_spy_called(self.cur.close)
+        self.spy_agency.assert_spy_called(self.cur.execute)
 
     @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch")
     def test_execute_sql_query(self, mock_es):
         mock_es_sql_client = MagicMock()
-        mock_es_sql_client.query.return_value = {
-            "columns": [{"name": "id"}, {"name": "first_name"}],
-            "rows": [[1, "John"], [2, "Jane"]],
-        }
+        mock_es_sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR
         mock_es.return_value.sql = mock_es_sql_client
 
         es_connection = ESConnection(host="localhost", port=9200)
-        response = es_connection.execute_sql("SELECT * FROM index1")
-        mock_es_sql_client.query.assert_called_once_with(body={"query": "SELECT * FROM index1"})
-
-        assert response["rows"] == [[1, "John"], [2, "Jane"]]
-        assert response["columns"] == [{"name": "id"}, {"name": "first_name"}]
+        response = es_connection.execute_sql("SELECT * FROM hollywood.actors")
+        mock_es_sql_client.query.assert_called_once_with(
+            body={
+                "fetch_size": 1000,
+                "field_multi_value_leniency": False,
+                "query": "SELECT * FROM hollywood.actors",
+            }
+        )
+
+        assert response == RESPONSE_WITHOUT_CURSOR
 
 
 class MockElasticsearch:

Reply via email to