This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5f47e60962 Custom fetch all handler for vertica to not miss errors (#34041)
5f47e60962 is described below
commit 5f47e60962b3123b1e6c8b42bef2c2643f54b601
Author: darkag <[email protected]>
AuthorDate: Wed Sep 6 23:09:53 2023 +0200
Custom fetch all handler for vertica to not miss errors (#34041)
* Custom fetch all handler for vertica to not miss errors
* missing parameter
* Fix test (set nextset to none)
* fix static checks
* fix static-check error
* fix static-check error
* rename variable
* add docstring
* fix docstring
---
airflow/providers/vertica/hooks/vertica.py | 80 ++++++++++++++++++++++++++-
tests/providers/vertica/hooks/test_vertica.py | 1 +
2 files changed, 78 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py
index 06b2e3cf17..91672e2aec 100644
--- a/airflow/providers/vertica/hooks/vertica.py
+++ b/airflow/providers/vertica/hooks/vertica.py
@@ -17,13 +17,45 @@
# under the License.
from __future__ import annotations
+from typing import Any, Callable, Iterable, Mapping, overload
+
from vertica_python import connect
-from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+
+
+def vertica_fetch_all_handler(cursor) -> list[tuple] | None:
+ """
+ Replace the default DbApiHook fetch_all_handler in order to fix this issue https://github.com/apache/airflow/issues/32993.
+
+ Returned value will not change after the initial call of fetch_all_handler, all the remaining code is here
+ only to make vertica client throws error.
+ With Vertica, if you run the following sql (with split_statements set to false):
+
+ INSERT INTO MyTable (Key, Label) values (1, 'test 1');
+ INSERT INTO MyTable (Key, Label) values (1, 'test 2');
+ INSERT INTO MyTable (Key, Label) values (3, 'test 3');
+
+ each insert will have its own result set and if you don't try to fetch data of those result sets
+ you won't detect error on the second insert.
+ """
+ result = fetch_all_handler(cursor)
+ # loop on all statement result sets to get errors
+ if cursor.description is not None:
+ while cursor.nextset():
+ if cursor.description is not None:
+ row = cursor.fetchone()
+ while row:
+ row = cursor.fetchone()
+ return result
class VerticaHook(DbApiHook):
- """Interact with Vertica."""
+ """
+ Interact with Vertica.
+
+ This hook use a customized version of default fetch_all_handler named vertica_fetch_all_handler.
+ """
conn_name_attr = "vertica_conn_id"
default_conn_name = "vertica_default"
@@ -32,7 +64,7 @@ class VerticaHook(DbApiHook):
supports_autocommit = True
def get_conn(self) -> connect:
- """Return verticaql connection object."""
+ """Return vertica connection object."""
conn = self.get_connection(self.vertica_conn_id) # type: ignore
conn_config = {
"user": conn.login,
@@ -99,3 +131,45 @@ class VerticaHook(DbApiHook):
conn = connect(**conn_config)
return conn
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: None = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = ...,
+ parameters: Iterable | Mapping[str, Any] | None = ...,
+ handler: Callable[[Any], Any] = ...,
+ split_statements: bool = ...,
+ return_last: bool = ...,
+ ) -> Any | list[Any]:
+ ...
+
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping | None = None,
+ handler: Callable[[Any], Any] | None = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Any | list[Any] | None:
+ """
+ Overwrite the common sql run.
+
+ Will automatically replace fetch_all_handler by vertica_fetch_all_handler.
+ """
+ if handler == fetch_all_handler:
+ handler = vertica_fetch_all_handler
+ return DbApiHook.run(self, sql, autocommit, parameters, handler, split_statements, return_last)
diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py
index 146c3bcd11..e5ff2538eb 100644
--- a/tests/providers/vertica/hooks/test_vertica.py
+++ b/tests/providers/vertica/hooks/test_vertica.py
@@ -127,6 +127,7 @@ class TestVerticaHookConn:
class TestVerticaHook:
def setup_method(self):
self.cur = mock.MagicMock(rowcount=0)
+ self.cur.nextset.side_effect = [None]
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn