This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1884f22 Pass Trino hook params to DbApiHook (#21479)
1884f22 is described below
commit 1884f2227d1e41d7bb37246ece4da5d871036c1f
Author: Dmytro Kazanzhy <[email protected]>
AuthorDate: Tue Feb 15 22:42:06 2022 +0200
Pass Trino hook params to DbApiHook (#21479)
---
airflow/providers/trino/hooks/trino.py | 20 ++++++++++++++------
tests/providers/trino/hooks/test_trino.py | 14 ++++++++++++--
2 files changed, 26 insertions(+), 8 deletions(-)
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 9b7a95b..4ec6f30 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
-from typing import Any, Iterable, Optional
+from typing import Any, Callable, Iterable, Optional
import trino
from trino.exceptions import DatabaseError
@@ -106,7 +106,7 @@ class TrinoHook(DbApiHook):
def _strip_sql(sql: str) -> str:
return sql.strip().rstrip(';')
- def get_records(self, hql, parameters: Optional[dict] = None):
+ def get_records(self, hql: str, parameters: Optional[dict] = None):
"""Get a set of records from Trino"""
try:
return super().get_records(self._strip_sql(hql), parameters)
@@ -120,7 +120,7 @@ class TrinoHook(DbApiHook):
except DatabaseError as e:
raise TrinoException(e)
- def get_pandas_df(self, hql, parameters=None, **kwargs):
+ def get_pandas_df(self, hql: str, parameters: Optional[dict] = None,
**kwargs): # type: ignore[override]
"""Get a pandas dataframe from a sql query."""
import pandas
@@ -138,9 +138,17 @@ class TrinoHook(DbApiHook):
df = pandas.DataFrame(**kwargs)
return df
- def run(self, hql, autocommit: bool = False, parameters: Optional[dict] =
None, handler=None) -> None:
+ def run(
+ self,
+ hql: str,
+ autocommit: bool = False,
+ parameters: Optional[dict] = None,
+ handler: Optional[Callable] = None,
+ ) -> None:
"""Execute the statement against Trino. Can be used to create views."""
- return super().run(sql=self._strip_sql(hql), parameters=parameters)
+ return super().run(
+ sql=self._strip_sql(hql), autocommit=autocommit,
parameters=parameters, handler=handler
+ )
def insert_rows(
self,
@@ -169,4 +177,4 @@ class TrinoHook(DbApiHook):
)
commit_every = 0
- super().insert_rows(table, rows, target_fields, commit_every)
+ super().insert_rows(table, rows, target_fields, commit_every, replace)
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index f17d741..09d5aa5 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -149,8 +149,9 @@ class TestTrinoHook(unittest.TestCase):
rows = [("hello",), ("world",)]
target_fields = None
commit_every = 10
- self.db_hook.insert_rows(table, rows, target_fields, commit_every)
- mock_insert_rows.assert_called_once_with(table, rows, None, 10)
+ replace = True
+ self.db_hook.insert_rows(table, rows, target_fields, commit_every,
replace)
+ mock_insert_rows.assert_called_once_with(table, rows, None, 10, True)
def test_get_first_record(self):
statement = 'SQL'
@@ -187,6 +188,15 @@ class TestTrinoHook(unittest.TestCase):
self.cur.execute.assert_called_once_with(statement, None)
+ @patch('airflow.hooks.dbapi.DbApiHook.run')
+ def test_run(self, mock_run):
+ hql = "SELECT 1"
+ autocommit = False
+ parameters = {"hello": "world"}
+ handler = str
+ self.db_hook.run(hql, autocommit, parameters, handler)
+ mock_run.assert_called_once_with(sql=hql, autocommit=False,
parameters=parameters, handler=str)
+
class TestTrinoHookIntegration(unittest.TestCase):
@pytest.mark.integration("trino")