This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f47e60962 Custom fetch all handler for vertica to not miss errors (#34041)
5f47e60962 is described below

commit 5f47e60962b3123b1e6c8b42bef2c2643f54b601
Author: darkag <[email protected]>
AuthorDate: Wed Sep 6 23:09:53 2023 +0200

    Custom fetch all handler for vertica to not miss errors (#34041)
    
    * Custom fetch all handler for vertica to not miss errors
    
    * missing parameter
    
    * Fix test (set nextset to none)
    
    * fix static checks
    
    * fix static-check error
    
    * fix static-check error
    
    * rename variable
    
    * add docstring
    
    * fix docstring
---
 airflow/providers/vertica/hooks/vertica.py    | 80 ++++++++++++++++++++++++++-
 tests/providers/vertica/hooks/test_vertica.py |  1 +
 2 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/vertica/hooks/vertica.py b/airflow/providers/vertica/hooks/vertica.py
index 06b2e3cf17..91672e2aec 100644
--- a/airflow/providers/vertica/hooks/vertica.py
+++ b/airflow/providers/vertica/hooks/vertica.py
@@ -17,13 +17,45 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Any, Callable, Iterable, Mapping, overload
+
 from vertica_python import connect
 
-from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+
+
def vertica_fetch_all_handler(cursor) -> list[tuple] | None:
    """
    Replace the default DbApiHook fetch_all_handler in order to fix this issue https://github.com/apache/airflow/issues/32993.

    The returned value is exactly what fetch_all_handler returns on its initial call; all the
    remaining code is here only to make the vertica client throw errors raised by later statements.
    With Vertica, if you run the following sql (with split_statements set to false):

    INSERT INTO MyTable (Key, Label) values (1, 'test 1');
    INSERT INTO MyTable (Key, Label) values (1, 'test 2');
    INSERT INTO MyTable (Key, Label) values (3, 'test 3');

    each insert will have its own result set and if you don't try to fetch data of those result sets
    you won't detect error on the second insert.

    :param cursor: DB-API cursor on which the (possibly multi-statement) SQL has been executed.
    :return: the value of fetch_all_handler for the first result set (None when it has no description).
    """
    result = fetch_all_handler(cursor)
    # Walk every remaining statement result set, even when the first statement produced none
    # (e.g. a DDL statement), so that errors raised by any later statement are surfaced.
    while cursor.nextset():
        if cursor.description is None:
            continue
        # Drain the rows. DB-API fetchone() signals exhaustion with None, not with a falsy row.
        while cursor.fetchone() is not None:
            pass
    return result
 
 
 class VerticaHook(DbApiHook):
-    """Interact with Vertica."""
+    """
+    Interact with Vertica.
+
+    This hook use a customized version of default fetch_all_handler named 
vertica_fetch_all_handler.
+    """
 
     conn_name_attr = "vertica_conn_id"
     default_conn_name = "vertica_default"
@@ -32,7 +64,7 @@ class VerticaHook(DbApiHook):
     supports_autocommit = True
 
     def get_conn(self) -> connect:
-        """Return verticaql connection object."""
+        """Return vertica connection object."""
         conn = self.get_connection(self.vertica_conn_id)  # type: ignore
         conn_config = {
             "user": conn.login,
@@ -99,3 +131,45 @@ class VerticaHook(DbApiHook):
 
         conn = connect(**conn_config)
         return conn
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: None = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> None:
+        ...
+
+    @overload
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = ...,
+        parameters: Iterable | Mapping[str, Any] | None = ...,
+        handler: Callable[[Any], Any] = ...,
+        split_statements: bool = ...,
+        return_last: bool = ...,
+    ) -> Any | list[Any]:
+        ...
+
+    def run(
+        self,
+        sql: str | Iterable[str],
+        autocommit: bool = False,
+        parameters: Iterable | Mapping | None = None,
+        handler: Callable[[Any], Any] | None = None,
+        split_statements: bool = False,
+        return_last: bool = True,
+    ) -> Any | list[Any] | None:
+        """
+        Overwrite the common sql run.
+
+        Will automatically replace fetch_all_handler by 
vertica_fetch_all_handler.
+        """
+        if handler == fetch_all_handler:
+            handler = vertica_fetch_all_handler
+        return DbApiHook.run(self, sql, autocommit, parameters, handler, 
split_statements, return_last)
diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py
index 146c3bcd11..e5ff2538eb 100644
--- a/tests/providers/vertica/hooks/test_vertica.py
+++ b/tests/providers/vertica/hooks/test_vertica.py
@@ -127,6 +127,7 @@ class TestVerticaHookConn:
 class TestVerticaHook:
     def setup_method(self):
         self.cur = mock.MagicMock(rowcount=0)
+        self.cur.nextset.side_effect = [None]
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn

Reply via email to