This is an automated email from the ASF dual-hosted git repository.
dimberman pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v1-10-test by this push:
new e8ec9a0 DbApiHook: Support kwargs in get_pandas_df (#9730)
e8ec9a0 is described below
commit e8ec9a0b0a79a1e57ac3b7ba10d1ef47d7a01079
Author: zikun <[email protected]>
AuthorDate: Wed Aug 12 17:09:27 2020 +0800
DbApiHook: Support kwargs in get_pandas_df (#9730)
* DbApiHook: Support kwargs in get_pandas_df
* BigQueryHook: Support kwargs in get_pandas_df
* PrestoHook: Support kwargs in get_pandas_df
* HiveServer2Hook: Support kwargs in get_pandas_df
(cherry picked from commit 8f8db8959e526be54d700845d36ee9f315bae2ea)
---
airflow/contrib/hooks/bigquery_hook.py | 7 +++++--
airflow/hooks/dbapi_hook.py | 8 +++++---
airflow/hooks/hive_hooks.py | 6 ++++--
airflow/hooks/presto_hook.py | 6 +++---
4 files changed, 17 insertions(+), 10 deletions(-)
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
index 07a2ab8..e99aa73 100644
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ b/airflow/contrib/hooks/bigquery_hook.py
@@ -93,7 +93,7 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
"""
raise NotImplementedError()
- def get_pandas_df(self, sql, parameters=None, dialect=None):
+ def get_pandas_df(self, sql, parameters=None, dialect=None, **kwargs):
"""
Returns a Pandas DataFrame for the results produced by a BigQuery
query. The DbApiHook method must be overridden because Pandas
@@ -110,6 +110,8 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
:param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
defaults to use `self.use_legacy_sql` if not specified
:type dialect: str in {'legacy', 'standard'}
+ :param kwargs: (optional) passed into pandas_gbq.read_gbq method
+ :type kwargs: dict
"""
private_key = self._get_field('key_path', None) or self._get_field('keyfile_dict', None)
@@ -120,7 +122,8 @@ class BigQueryHook(GoogleCloudBaseHook, DbApiHook):
project_id=self._get_field('project'),
dialect=dialect,
verbose=False,
- private_key=private_key)
+ private_key=private_key,
+ **kwargs)
def table_exists(self, project_id, dataset_id, table_id):
"""
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index ac54881..76f4f0a 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -82,7 +82,7 @@ class DbApiHook(BaseHook):
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)
- def get_pandas_df(self, sql, parameters=None):
+ def get_pandas_df(self, sql, parameters=None, **kwargs):
"""
Executes the sql and returns a pandas dataframe
@@ -90,14 +90,16 @@ class DbApiHook(BaseHook):
sql statements to execute
:type sql: str or list
:param parameters: The parameters to render the SQL query with.
- :type parameters: mapping or iterable
+ :type parameters: dict or iterable
+ :param kwargs: (optional) passed into pandas.io.sql.read_sql method
+ :type kwargs: dict
"""
if sys.version_info[0] < 3:
sql = sql.encode('utf-8')
import pandas.io.sql as psql
with closing(self.get_conn()) as conn:
- return psql.read_sql(sql, con=conn, params=parameters)
+ return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
def get_records(self, sql, parameters=None):
"""
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
index e521d7b..48def11 100644
--- a/airflow/hooks/hive_hooks.py
+++ b/airflow/hooks/hive_hooks.py
@@ -983,7 +983,7 @@ class HiveServer2Hook(BaseHook):
"""
return self.get_results(hql, schema=schema)['data']
- def get_pandas_df(self, hql, schema='default'):
+ def get_pandas_df(self, hql, schema='default', **kwargs):
"""
Get a pandas dataframe from a Hive query
@@ -991,6 +991,8 @@ class HiveServer2Hook(BaseHook):
:type hql: str or list
:param schema: target schema, default to 'default'.
:type schema: str
+ :param kwargs: (optional) passed into pandas.DataFrame constructor
+ :type kwargs: dict
:return: result of hql execution
:rtype: DataFrame
@@ -1004,6 +1006,6 @@ class HiveServer2Hook(BaseHook):
"""
import pandas as pd
res = self.get_results(hql, schema=schema)
- df = pd.DataFrame(res['data'])
+ df = pd.DataFrame(res['data'], **kwargs)
df.columns = [c[0] for c in res['header']]
return df
diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py
index 9788411..7d700ab 100644
--- a/airflow/hooks/presto_hook.py
+++ b/airflow/hooks/presto_hook.py
@@ -105,7 +105,7 @@ class PrestoHook(DbApiHook):
except DatabaseError as e:
raise PrestoException(self._get_pretty_exception_message(e))
- def get_pandas_df(self, hql, parameters=None):
+ def get_pandas_df(self, hql, parameters=None, **kwargs):
"""
Get a pandas dataframe from a sql query.
"""
@@ -118,10 +118,10 @@ class PrestoHook(DbApiHook):
raise PrestoException(self._get_pretty_exception_message(e))
column_descriptions = cursor.description
if data:
- df = pandas.DataFrame(data)
+ df = pandas.DataFrame(data, **kwargs)
df.columns = [c[0] for c in column_descriptions]
else:
- df = pandas.DataFrame()
+ df = pandas.DataFrame(**kwargs)
return df
def run(self, hql, parameters=None):