This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1884f22  Pass Trino hook params to DbApiHook (#21479)
1884f22 is described below

commit 1884f2227d1e41d7bb37246ece4da5d871036c1f
Author: Dmytro Kazanzhy <[email protected]>
AuthorDate: Tue Feb 15 22:42:06 2022 +0200

    Pass Trino hook params to DbApiHook (#21479)
---
 airflow/providers/trino/hooks/trino.py    | 20 ++++++++++++++------
 tests/providers/trino/hooks/test_trino.py | 14 ++++++++++++--
 2 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 9b7a95b..4ec6f30 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import os
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional
 
 import trino
 from trino.exceptions import DatabaseError
@@ -106,7 +106,7 @@ class TrinoHook(DbApiHook):
     def _strip_sql(sql: str) -> str:
         return sql.strip().rstrip(';')
 
-    def get_records(self, hql, parameters: Optional[dict] = None):
+    def get_records(self, hql: str, parameters: Optional[dict] = None):
         """Get a set of records from Trino"""
         try:
             return super().get_records(self._strip_sql(hql), parameters)
@@ -120,7 +120,7 @@ class TrinoHook(DbApiHook):
         except DatabaseError as e:
             raise TrinoException(e)
 
-    def get_pandas_df(self, hql, parameters=None, **kwargs):
+    def get_pandas_df(self, hql: str, parameters: Optional[dict] = None, **kwargs):  # type: ignore[override]
         """Get a pandas dataframe from a sql query."""
         import pandas
 
@@ -138,9 +138,17 @@ class TrinoHook(DbApiHook):
             df = pandas.DataFrame(**kwargs)
         return df
 
-    def run(self, hql, autocommit: bool = False, parameters: Optional[dict] = None, handler=None) -> None:
+    def run(
+        self,
+        hql: str,
+        autocommit: bool = False,
+        parameters: Optional[dict] = None,
+        handler: Optional[Callable] = None,
+    ) -> None:
         """Execute the statement against Trino. Can be used to create views."""
-        return super().run(sql=self._strip_sql(hql), parameters=parameters)
+        return super().run(
+            sql=self._strip_sql(hql), autocommit=autocommit, parameters=parameters, handler=handler
+        )
 
     def insert_rows(
         self,
@@ -169,4 +177,4 @@ class TrinoHook(DbApiHook):
             )
             commit_every = 0
 
-        super().insert_rows(table, rows, target_fields, commit_every)
+        super().insert_rows(table, rows, target_fields, commit_every, replace)
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index f17d741..09d5aa5 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -149,8 +149,9 @@ class TestTrinoHook(unittest.TestCase):
         rows = [("hello",), ("world",)]
         target_fields = None
         commit_every = 10
-        self.db_hook.insert_rows(table, rows, target_fields, commit_every)
-        mock_insert_rows.assert_called_once_with(table, rows, None, 10)
+        replace = True
+        self.db_hook.insert_rows(table, rows, target_fields, commit_every, replace)
+        mock_insert_rows.assert_called_once_with(table, rows, None, 10, True)
 
     def test_get_first_record(self):
         statement = 'SQL'
@@ -187,6 +188,15 @@ class TestTrinoHook(unittest.TestCase):
 
         self.cur.execute.assert_called_once_with(statement, None)
 
+    @patch('airflow.hooks.dbapi.DbApiHook.run')
+    def test_run(self, mock_run):
+        hql = "SELECT 1"
+        autocommit = False
+        parameters = {"hello": "world"}
+        handler = str
+        self.db_hook.run(hql, autocommit, parameters, handler)
+        mock_run.assert_called_once_with(sql=hql, autocommit=False, parameters=parameters, handler=str)
+
 
 class TestTrinoHookIntegration(unittest.TestCase):
     @pytest.mark.integration("trino")

Reply via email to