This is an automated email from the ASF dual-hosted git repository.

dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v1-10-test by this push:
     new e8ec9a0  DbApiHook: Support kwargs in get_pandas_df (#9730)
e8ec9a0 is described below

commit e8ec9a0b0a79a1e57ac3b7ba10d1ef47d7a01079
Author: zikun <[email protected]>
AuthorDate: Wed Aug 12 17:09:27 2020 +0800

    DbApiHook: Support kwargs in get_pandas_df (#9730)
    
    * DbApiHook: Support kwargs in get_pandas_df
    * BigQueryHook: Support kwargs in get_pandas_df
    * PrestoHook: Support kwargs in get_pandas_df
    * HiveServer2Hook: Support kwargs in get_pandas_df
    
    (cherry picked from commit 8f8db8959e526be54d700845d36ee9f315bae2ea)
---
 airflow/contrib/hooks/bigquery_hook.py | 7 +++++--
 airflow/hooks/dbapi_hook.py            | 8 +++++---
 airflow/hooks/hive_hooks.py            | 6 ++++--
 airflow/hooks/presto_hook.py           | 6 +++---
 4 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 07a2ab8..e99aa73 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -93,7 +93,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
         """
         raise NotImplementedError()
 
-    def get_pandas_df(self, sql, parameters=None, dialect=None):
+    def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
         """
         Returns a Pandas DataFrame for the results produced by a BigQuery
         query. The DbApiHook method must be overridden because Pandas
@@ -110,6 +110,8 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
         :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
             defaults to use `self.use_legacy_sql` if not specified
         :type dialect: str in {'legacy', 'standard'}
+        :param kwargs: (optional) passed into pandas_gbq.read_gbq method
+        :type kwargs: dict
         """
         private_key = self._get_field('key_path', None) or self._get_field('keyfile_dict', None)
 
@@ -120,7 +122,8 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
                         project_id=self._get_field('project'),
                         dialect=dialect,
                         verbose=False,
-                        private_key=private_key)
+                        private_key=private_key,
+                        **kwargs)
 
     def table_exists(self, project_id, dataset_id, table_id):
         """
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index ac54881..76f4f0a 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -82,7 +82,7 @@ class DbApiHook(BaseHook):
             engine_kwargs = {}
         return create_engine(self.get_uri(), **engine_kwargs)
 
-    def get_pandas_df(self, sql, parameters=None):
+    def get_pandas_df(self, sql, parameters=None, **kwargs):
         """
         Executes the sql and returns a pandas dataframe
 
@@ -90,14 +90,16 @@ class DbApiHook(BaseHook):
             sql statements to execute
         :type sql: str or list
         :param parameters: The parameters to render the SQL query with.
-        :type parameters: mapping or iterable
+        :type parameters: dict or iterable
+        :param kwargs: (optional) passed into pandas.io.sql.read_sql method
+        :type kwargs: dict
         """
         if sys.version_info[0] < 3:
             sql = sql.encode('utf-8')
         import pandas.io.sql as psql
 
         with closing(self.get_conn()) as conn:
-            return psql.read_sql(sql, con=conn, params=parameters)
+            return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
 
     def get_records(self, sql, parameters=None):
         """
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index e521d7b..48def11 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -983,7 +983,7 @@ class HiveServer2Hook(BaseHook):
         """
         return self.get_results(hql, schema=schema)['data']
 
-    def get_pandas_df(self, hql, schema='default'):
+    def get_pandas_df(self, hql, schema='default', **kwargs):
         """
         Get a pandas dataframe from a Hive query
 
@@ -991,6 +991,8 @@ class HiveServer2Hook(BaseHook):
         :type hql: str or list
         :param schema: target schema, default to 'default'.
         :type schema: str
+        :param kwargs: (optional) passed into pandas.DataFrame constructor
+        :type kwargs: dict
         :return: result of hql execution
         :rtype: DataFrame
 
@@ -1004,6 +1006,6 @@ class HiveServer2Hook(BaseHook):
         """
         import pandas as pd
         res = self.get_results(hql, schema=schema)
-        df = pd.DataFrame(res['data'])
+        df = pd.DataFrame(res['data'], **kwargs)
         df.columns = [c[0] for c in res['header']]
         return df
diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py
index 9788411..7d700ab 100644
--- a/airflow/hooks/presto_hook.py
+++ b/airflow/hooks/presto_hook.py
@@ -105,7 +105,7 @@ class PrestoHook(DbApiHook):
         except DatabaseError as e:
             raise PrestoException(self._get_pretty_exception_message(e))
 
-    def get_pandas_df(self, hql, parameters=None):
+    def get_pandas_df(self, hql, parameters=None, **kwargs):
         """
         Get a pandas dataframe from a sql query.
         """
@@ -118,10 +118,10 @@ class PrestoHook(DbApiHook):
             raise PrestoException(self._get_pretty_exception_message(e))
         column_descriptions = cursor.description
         if data:
-            df = pandas.DataFrame(data)
+            df = pandas.DataFrame(data, **kwargs)
             df.columns = [c[0] for c in column_descriptions]
         else:
-            df = pandas.DataFrame()
+            df = pandas.DataFrame(**kwargs)
         return df
 
     def run(self, hql, parameters=None):

Reply via email to