This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 36010f6d0e Fix: Implement support for `fetchone()` in the ODBCHook and the Databricks SQL Hook (#36161)
36010f6d0e is described below

commit 36010f6d0e3231081dbae095baff5a5b5c5b34eb
Author: Joffrey Bienvenu <[email protected]>
AuthorDate: Mon Dec 11 08:02:11 2023 +0100

    Fix: Implement support for `fetchone()` in the ODBCHook and the Databricks SQL Hook (#36161)
    
    * feat: Implement fetchone() support for pyodbc.Row
    
    * feat: Implement fetchone() support for databricks.sql.Row
    
    * fix: improve docstring
---
 .../providers/databricks/hooks/databricks_sql.py   |  7 ++-
 airflow/providers/odbc/hooks/odbc.py               | 14 ++++--
 .../databricks/hooks/test_databricks_sql.py        | 11 +++++
 tests/providers/odbc/hooks/test_odbc.py            | 56 ++++++++++++++++++----
 4 files changed, 73 insertions(+), 15 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_sql.py 
b/airflow/providers/databricks/hooks/databricks_sql.py
index d61d9f1bd7..dc728c5ed7 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -21,6 +21,7 @@ from copy import copy
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, 
overload
 
 from databricks import sql  # type: ignore[attr-defined]
+from databricks.sql.types import Row
 
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook, 
return_single_query_results
@@ -242,9 +243,11 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
 
     @staticmethod
     def _make_serializable(result):
-        """Transform the databricks Row objects into a JSON-serializable list 
of rows."""
-        if result is not None:
+        """Transform the databricks Row objects into JSON-serializable 
lists."""
+        if isinstance(result, list):
             return [list(row) for row in result]
+        elif isinstance(result, Row):
+            return list(result)
         return result
 
     def bulk_dump(self, table, tmp_file):
diff --git a/airflow/providers/odbc/hooks/odbc.py 
b/airflow/providers/odbc/hooks/odbc.py
index 8242aa5247..d84933dc29 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -213,12 +213,16 @@ class OdbcHook(DbApiHook):
         return cnx
 
     @staticmethod
-    def _make_serializable(result: list[pyodbc.Row] | None) -> 
list[NamedTuple] | None:
+    def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) -> 
list[NamedTuple] | None:
         """Transform the pyodbc.Row objects returned from an SQL command into 
JSON-serializable NamedTuple."""
-        if result is not None:
-            columns: list[tuple[str, type]] = [col[:2] for col in 
result[0].cursor_description]
-            # Below line respects NamedTuple docstring, but mypy do not 
support dynamically
-            # instantiated Namedtuple, and will never do: 
https://github.com/python/mypy/issues/848
+        # Below ignored lines respect NamedTuple docstring, but mypy do not 
support dynamically
+        # instantiated Namedtuple, and will never do: 
https://github.com/python/mypy/issues/848
+        columns: list[tuple[str, type]] | None = None
+        if isinstance(result, list):
+            columns = [col[:2] for col in result[0].cursor_description]
             row_object = NamedTuple("Row", columns)  # type: ignore[misc]
             return [row_object(*row) for row in result]
+        elif isinstance(result, pyodbc.Row):
+            columns = [col[:2] for col in result.cursor_description]
+            return NamedTuple("Row", columns)(*result)  # type: ignore[misc, 
operator]
         return result
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py 
b/tests/providers/databricks/hooks/test_databricks_sql.py
index 1be035c443..64cd0b9c06 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -172,6 +172,17 @@ def get_cursor_descriptions(fields: list[str]) -> 
list[tuple[str]]:
             [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
             id="The return_last not set on multiple queries not set",
         ),
+        pytest.param(
+            True,
+            False,
+            "select * from test.test",
+            ["select * from test.test"],
+            [["id", "value"]],
+            (Row(id=1, value=2),),
+            [[("id",), ("value",)]],
+            [1, 2],
+            id="The return_last set and no split statements set on single 
query in string",
+        ),
     ],
 )
 def test_query(
diff --git a/tests/providers/odbc/hooks/test_odbc.py 
b/tests/providers/odbc/hooks/test_odbc.py
index 03e09a8adf..3740a5654c 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -32,9 +32,12 @@ from airflow.providers.odbc.hooks.odbc import OdbcHook
 
 
 @pytest.fixture
-def mock_row():
-    """
-    Mock a pyodbc.Row object - This is a C object that can only be created 
from C API of pyodbc.
+def pyodbc_row_mock():
+    """Mock a pyodbc.Row instantiated object.
+
+    This object is used in the tests to replace the real pyodbc.Row object.
+    pyodbc.Row is a C object that can only be created from C API of pyodbc.
+
     This mock implements the two features used by the hook:
         - cursor_description: which return column names and type
         - __iter__: which allows exploding a row instance (*row)
@@ -59,6 +62,20 @@ def mock_row():
     return Row
 
 
[email protected]
+def pyodbc_instancecheck():
+    """Mock a pyodbc.Row class which returns True to any isinstance() 
checks."""
+
+    class PyodbcRowMeta(type):
+        def __instancecheck__(self, instance):
+            return True
+
+    class PyodbcRow(metaclass=PyodbcRowMeta):
+        pass
+
+    return PyodbcRow
+
+
 class TestOdbcHook:
     def get_hook(self=None, hook_params=None, conn_params=None):
         hook_params = hook_params or {}
@@ -282,14 +299,18 @@ class TestOdbcHook:
     def test_pyodbc_mock(self):
         """Ensure that pyodbc.Row object has a `cursor_description` method.
 
-        In subsequent tests, pyodbc.Row is replaced by pure Python mock 
object, which implements the above
-        method. We want to detect any breaking change in the pyodbc object. If 
it fails, the 'mock_row'
-        needs to be updated.
+        In subsequent tests, pyodbc.Row is replaced by the 'pyodbc_row_mock' 
fixture, which implements the
+        `cursor_description` method. We want to detect any breaking change in 
the pyodbc object. If this test
+        fails, the 'pyodbc_row_mock' fixture needs to be updated.
         """
         assert hasattr(pyodbc.Row, "cursor_description")
 
-    def test_query_return_serializable_result(self, mock_row):
-        pyodbc_result = [mock_row(key=1, column="value1"), mock_row(key=2, 
column="value2")]
+    def test_query_return_serializable_result_with_fetchall(self, 
pyodbc_row_mock):
+        """
+        Simulate a cursor.fetchall which returns an iterable of pyodbc.Row 
object, and check if this iterable
+        get converted into a list of tuples.
+        """
+        pyodbc_result = [pyodbc_row_mock(key=1, column="value1"), 
pyodbc_row_mock(key=2, column="value2")]
         hook_result = [(1, "value1"), (2, "value2")]
 
         def mock_handler(*_):
@@ -299,6 +320,25 @@ class TestOdbcHook:
         result = hook.run("SQL", handler=mock_handler)
         assert hook_result == result
 
+    def test_query_return_serializable_result_with_fetchone(
+        self, pyodbc_row_mock, monkeypatch, pyodbc_instancecheck
+    ):
+        """
+        Simulate a cursor.fetchone which returns one single pyodbc.Row object, 
and check if this object gets
+        converted into a tuple.
+        """
+        pyodbc_result = pyodbc_row_mock(key=1, column="value1")
+        hook_result = (1, "value1")
+
+        def mock_handler(*_):
+            return pyodbc_result
+
+        hook = self.get_hook()
+        with monkeypatch.context() as patcher:
+            patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
+            result = hook.run("SQL", handler=mock_handler)
+        assert hook_result == result
+
     def test_query_no_handler_return_none(self):
         hook = self.get_hook()
         result = hook.run("SQL")

Reply via email to