This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8914e49551 SqlToS3Operator: feat/ add max_rows_per_file parameter (#37055)
8914e49551 is described below
commit 8914e49551d8ae5ece7418950b011c1f338b4634
Author: Selim CHERGUI <[email protected]>
AuthorDate: Tue Jan 30 01:27:46 2024 +0100
SqlToS3Operator: feat/ add max_rows_per_file parameter (#37055)
---------
Co-authored-by: Selim Chergui <[email protected]>
Co-authored-by: Jarek Potiuk <[email protected]>
---
.../providers/amazon/aws/transfers/sql_to_s3.py | 34 ++++++++++++--
.../amazon/aws/transfers/test_sql_to_s3.py | 52 ++++++++++++++++++++++
2 files changed, 83 insertions(+), 3 deletions(-)
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index c00784ad4a..f8691fa4a2 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
+ :param max_rows_per_file: (optional) argument to set the maximum number of rows in each destination file; if the
+ source data is larger than that, it will be split into multiple files.
+ Will be ignored if the ``groupby_kwargs`` argument is specified.
:param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param groupby_kwargs: argument to include in DataFrame ``groupby()``.
"""
@@ -110,6 +113,7 @@ class SqlToS3Operator(BaseOperator):
aws_conn_id: str = "aws_default",
verify: bool | str | None = None,
file_format: Literal["csv", "json", "parquet"] = "csv",
+ max_rows_per_file: int = 0,
pd_kwargs: dict | None = None,
groupby_kwargs: dict | None = None,
**kwargs,
@@ -124,12 +128,19 @@ class SqlToS3Operator(BaseOperator):
self.replace = replace
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
+ self.max_rows_per_file = max_rows_per_file
self.groupby_kwargs = groupby_kwargs or {}
self.sql_hook_params = sql_hook_params
if "path_or_buf" in self.pd_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")
+ if self.max_rows_per_file and self.groupby_kwargs:
+ raise AirflowException(
+ "SqlToS3Operator arguments max_rows_per_file and
groupby_kwargs "
+ "can not be both specified. Please choose one."
+ )
+
try:
self.file_format = FILE_FORMAT[file_format.upper()]
except KeyError:
@@ -177,10 +188,8 @@ class SqlToS3Operator(BaseOperator):
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
self.log.info("Data from SQL obtained")
-
self._fix_dtypes(data_df, self.file_format)
file_options = FILE_OPTIONS_MAP[self.file_format]
-
for group_name, df in self._partition_dataframe(df=data_df):
with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@ class SqlToS3Operator(BaseOperator):
def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
"""Partition dataframe using pandas groupby() method."""
+ try:
+ import secrets
+ import string
+
+ import numpy as np
+ except ImportError:
+ pass
+ # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
+ # added to the dataframe. This column is used to split the dataframe into smaller ones using groupby()
+ random_column_name = ""
+ if self.max_rows_per_file and not self.groupby_kwargs:
+ random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
+ df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+ self.groupby_kwargs = {"by": random_column_name}
if not self.groupby_kwargs:
yield "", df
return
for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
yield (
cast(str, group_label),
- cast("pd.DataFrame",
grouped_df.get_group(group_label).reset_index(drop=True)),
+ cast(
+ "pd.DataFrame",
+ grouped_df.get_group(group_label)
+ .drop(random_column_name, axis=1, errors="ignore")
+ .reset_index(drop=True),
+ ),
)
def _get_hook(self) -> DbApiHook:
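The chunking trick used in _partition_dataframe above can be seen in isolation below: a helper column holding row_index // max_rows_per_file gives every block of max_rows_per_file rows the same value, so an ordinary pandas groupby() on that column yields fixed-size partitions. This is a minimal standalone sketch of the idea (the random 20-letter column name and the operator plumbing are left out), not the committed code:

    import numpy as np
    import pandas as pd

    max_rows_per_file = 3
    df = pd.DataFrame({"value": range(8)})

    # rows 0-2 get chunk id 0, rows 3-5 get 1, rows 6-7 get 2
    df["_chunk"] = np.arange(len(df)) // max_rows_per_file

    for chunk_id, chunk in df.groupby("_chunk"):
        # drop the helper column before writing, mirroring the drop(..., errors="ignore") above
        print(chunk_id, len(chunk.drop(columns="_chunk")))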
diff --git a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
index cc56fd064a..feee688d46 100644
--- a/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -271,6 +271,58 @@ class TestSqlToS3Operator:
)
)
+ def test_with_max_rows_per_file(self):
+ """
+ Test operator when the max_rows_per_file is specified
+ """
+ query = "query"
+ s3_bucket = "bucket"
+ s3_key = "key"
+
+ op = SqlToS3Operator(
+ query=query,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ sql_conn_id="mysql_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ replace=True,
+ pd_kwargs={"index": False, "header": False},
+ max_rows_per_file=3,
+ dag=None,
+ )
+ example = {
+ "Team": ["Australia", "Australia", "India", "India"],
+ "Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
+ "Runs": [345, 490, 672, 560],
+ }
+
+ df = pd.DataFrame(example)
+ data = []
+ for group_name, df in op._partition_dataframe(df):
+ data.append((group_name, df))
+ data.sort(key=lambda d: d[0])
+ team, df = data[0]
+ assert df.equals(
+ pd.DataFrame(
+ {
+ "Team": ["Australia", "Australia", "India"],
+ "Player": ["Ricky", "David Warner", "Virat Kohli"],
+ "Runs": [345, 490, 672],
+ }
+ )
+ )
+ team, df = data[1]
+ assert df.equals(
+ pd.DataFrame(
+ {
+ "Team": ["India"],
+ "Player": ["Rohit Sharma"],
+ "Runs": [560],
+ }
+ )
+ )
+
@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
def test_hook_params(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")