This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 36010f6d0e Fix: Implement support for `fetchone()` in the ODBCHook and the Databricks SQL Hook (#36161)
36010f6d0e is described below
commit 36010f6d0e3231081dbae095baff5a5b5c5b34eb
Author: Joffrey Bienvenu <[email protected]>
AuthorDate: Mon Dec 11 08:02:11 2023 +0100
Fix: Implement support for `fetchone()` in the ODBCHook and the Databricks SQL Hook (#36161)
* feat: Implement fetchone() support for pyodbc.Row
* feat: Implement fetchone() support for databricks.sql.Row
* fix: improve docstring
---
.../providers/databricks/hooks/databricks_sql.py | 7 ++-
airflow/providers/odbc/hooks/odbc.py | 14 ++++--
.../databricks/hooks/test_databricks_sql.py | 11 +++++
tests/providers/odbc/hooks/test_odbc.py | 56 ++++++++++++++++++----
4 files changed, 73 insertions(+), 15 deletions(-)
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index d61d9f1bd7..dc728c5ed7 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -21,6 +21,7 @@ from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar,
overload
from databricks import sql # type: ignore[attr-defined]
+from databricks.sql.types import Row
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook,
return_single_query_results
@@ -242,9 +243,11 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
@staticmethod
def _make_serializable(result):
- """Transform the databricks Row objects into a JSON-serializable list
of rows."""
- if result is not None:
+ """Transform the databricks Row objects into JSON-serializable
lists."""
+ if isinstance(result, list):
return [list(row) for row in result]
+ elif isinstance(result, Row):
+ return list(result)
return result
def bulk_dump(self, table, tmp_file):
diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py
index 8242aa5247..d84933dc29 100644
--- a/airflow/providers/odbc/hooks/odbc.py
+++ b/airflow/providers/odbc/hooks/odbc.py
@@ -213,12 +213,16 @@ class OdbcHook(DbApiHook):
return cnx
@staticmethod
- def _make_serializable(result: list[pyodbc.Row] | None) ->
list[NamedTuple] | None:
+ def _make_serializable(result: list[pyodbc.Row] | pyodbc.Row | None) ->
list[NamedTuple] | None:
"""Transform the pyodbc.Row objects returned from an SQL command into
JSON-serializable NamedTuple."""
- if result is not None:
- columns: list[tuple[str, type]] = [col[:2] for col in
result[0].cursor_description]
- # Below line respects NamedTuple docstring, but mypy do not
support dynamically
- # instantiated Namedtuple, and will never do:
https://github.com/python/mypy/issues/848
+ # Below ignored lines respect NamedTuple docstring, but mypy do not
support dynamically
+ # instantiated Namedtuple, and will never do:
https://github.com/python/mypy/issues/848
+ columns: list[tuple[str, type]] | None = None
+ if isinstance(result, list):
+ columns = [col[:2] for col in result[0].cursor_description]
row_object = NamedTuple("Row", columns) # type: ignore[misc]
return [row_object(*row) for row in result]
+ elif isinstance(result, pyodbc.Row):
+ columns = [col[:2] for col in result.cursor_description]
+ return NamedTuple("Row", columns)(*result) # type: ignore[misc,
operator]
return result
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 1be035c443..64cd0b9c06 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -172,6 +172,17 @@ def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:
[[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
id="The return_last not set on multiple queries not set",
),
+ pytest.param(
+ True,
+ False,
+ "select * from test.test",
+ ["select * from test.test"],
+ [["id", "value"]],
+ (Row(id=1, value=2),),
+ [[("id",), ("value",)]],
+ [1, 2],
+ id="The return_last set and no split statements set on single
query in string",
+ ),
],
)
def test_query(
diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py
index 03e09a8adf..3740a5654c 100644
--- a/tests/providers/odbc/hooks/test_odbc.py
+++ b/tests/providers/odbc/hooks/test_odbc.py
@@ -32,9 +32,12 @@ from airflow.providers.odbc.hooks.odbc import OdbcHook
@pytest.fixture
-def mock_row():
- """
- Mock a pyodbc.Row object - This is a C object that can only be created
from C API of pyodbc.
+def pyodbc_row_mock():
+ """Mock a pyodbc.Row instantiated object.
+
+ This object is used in the tests to replace the real pyodbc.Row object.
+ pyodbc.Row is a C object that can only be created from C API of pyodbc.
+
This mock implements the two features used by the hook:
- cursor_description: which return column names and type
- __iter__: which allows exploding a row instance (*row)
@@ -59,6 +62,20 @@ def mock_row():
return Row
+@pytest.fixture
+def pyodbc_instancecheck():
+ """Mock a pyodbc.Row class which returns True to any isinstance()
checks."""
+
+ class PyodbcRowMeta(type):
+ def __instancecheck__(self, instance):
+ return True
+
+ class PyodbcRow(metaclass=PyodbcRowMeta):
+ pass
+
+ return PyodbcRow
+
+
class TestOdbcHook:
def get_hook(self=None, hook_params=None, conn_params=None):
hook_params = hook_params or {}
@@ -282,14 +299,18 @@ class TestOdbcHook:
def test_pyodbc_mock(self):
"""Ensure that pyodbc.Row object has a `cursor_description` method.
- In subsequent tests, pyodbc.Row is replaced by pure Python mock
object, which implements the above
- method. We want to detect any breaking change in the pyodbc object. If
it fails, the 'mock_row'
- needs to be updated.
+ In subsequent tests, pyodbc.Row is replaced by the 'pyodbc_row_mock'
fixture, which implements the
+ `cursor_description` method. We want to detect any breaking change in
the pyodbc object. If this test
+ fails, the 'pyodbc_row_mock' fixture needs to be updated.
"""
assert hasattr(pyodbc.Row, "cursor_description")
- def test_query_return_serializable_result(self, mock_row):
- pyodbc_result = [mock_row(key=1, column="value1"), mock_row(key=2,
column="value2")]
+ def test_query_return_serializable_result_with_fetchall(self,
pyodbc_row_mock):
+ """
+ Simulate a cursor.fetchall which returns an iterable of pyodbc.Row
object, and check if this iterable
+ get converted into a list of tuples.
+ """
+ pyodbc_result = [pyodbc_row_mock(key=1, column="value1"),
pyodbc_row_mock(key=2, column="value2")]
hook_result = [(1, "value1"), (2, "value2")]
def mock_handler(*_):
@@ -299,6 +320,25 @@ class TestOdbcHook:
result = hook.run("SQL", handler=mock_handler)
assert hook_result == result
+ def test_query_return_serializable_result_with_fetchone(
+ self, pyodbc_row_mock, monkeypatch, pyodbc_instancecheck
+ ):
+ """
+ Simulate a cursor.fetchone which returns one single pyodbc.Row object,
and check if this object gets
+ converted into a tuple.
+ """
+ pyodbc_result = pyodbc_row_mock(key=1, column="value1")
+ hook_result = (1, "value1")
+
+ def mock_handler(*_):
+ return pyodbc_result
+
+ hook = self.get_hook()
+ with monkeypatch.context() as patcher:
+ patcher.setattr("pyodbc.Row", pyodbc_instancecheck)
+ result = hook.run("SQL", handler=mock_handler)
+ assert hook_result == result
+
def test_query_no_handler_return_none(self):
hook = self.get_hook()
result = hook.run("SQL")