@Niranda Perera thanks for the clarification
@Weston Pace thanks for the detailed answer, each bit was very informative!
Indeed, as you mentioned, wrapping the scanner object is necessary, otherwise ds.write_dataset will not perform any consolidation; so my first example transforms but does not consolidate (e.g. if the input has several parquet files in a partition they **will not** be saved as a single file in the output).
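(For reference, this is roughly what I mean by "consolidation"; a minimal sketch reusing the paths of the example below and assuming default writer settings, where, as far as I understand, a single ds.write_dataset call fed by one Scanner produces one file per partition:)
```
import pyarrow.dataset as ds

# One Scanner over the whole multi-file input...
dataset = ds.dataset('example_input', partitioning='hive')
# ...written out in a single call: with default writer settings each hive
# partition should end up as one parquet file in the (placeholder) output dir.
ds.write_dataset(dataset.scanner(), 'consolidated_output',
                 format='parquet',
                 partitioning=['partition'],
                 partitioning_flavor='hive')
```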
I would like to implement a mechanism so I can pass arbitrary
transformations, similar to what you mention regarding the future python
UDF support.
Right now, since I need to transform both the batch and the schema, I put both transformations in a single object and came up with the following working example based on what you suggested (code appended below). To implement a new column-wise function the user only has to implement a method that takes a pandas DataFrame and returns a Series; the object does the rest.
Now, this workaround of tying batch and schema modifications together with an ad-hoc object works. Ideally, though, I would like to create a new, modified scanner object that performs the transformation under the hood while exposing the regular scanner API; that way ds.write_dataset would not need the additional schema argument. It would be nice to modify a scanner object through a decorator. Would this be possible?
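For instance, I imagine something along these lines (just a sketch; `wrap_scanner` is a hypothetical helper, and it relies on ds.Scanner.from_batches, which as far as I can tell accepts an iterator of record batches plus a schema):
```
import pyarrow.dataset as ds

def wrap_scanner(transformer) -> ds.Scanner:
    # Re-expose the transformed batches through the regular Scanner API, so
    # downstream code (e.g. ds.write_dataset) no longer needs a separate
    # schema argument.
    return ds.Scanner.from_batches(transformer.transform_scanner(),
                                   schema=transformer.schema)

# e.g. ds.write_dataset(wrap_scanner(transformer), out_arrow_path, format='parquet', ...)
```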
Also, would there be any potential issues when the dataset is saved to a cloud filesystem such as S3?
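(For context, this is roughly the S3 write I have in mind; just a sketch, with the bucket name and region as placeholders:)
```
import pyarrow.fs as fs

s3 = fs.S3FileSystem(region='us-east-1')  # placeholder region
ds.write_dataset(wrapped_scanner, 'my-bucket/example_output',  # placeholder bucket/path
                 filesystem=s3,
                 schema=transformer.schema,
                 format='parquet',
                 partitioning=['partition'],
                 partitioning_flavor='hive')
```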
Best,
Antonio
Example [1]
```
from abc import ABC, abstractmethod
import shutil
from typing import Iterator, Optional
import pyarrow.dataset as ds
import pyarrow as pa
import pyarrow.parquet as pq
from glob import glob
from uuid import uuid4
import pandas as pd
def file_visitor(written_file):
print(f"Saved partitioned file file_path={str(written_file.path)}")
# table = pq.read_table(written_file.path)
parquet_file = pq.ParquetFile(written_file.path)
print(f'{parquet_file.metadata=}')
print(f'{parquet_file.schema=}')
def append_column_to_batch(batch: pa.RecordBatch, arr: pa.Array,
                           name: str) -> pa.RecordBatch:
"""
Utility function to append a column to a RecordBatch.
This is useful until pa.RecordBatches does add an `.append_column`
method
"""
# NOTE: It's a metadata-only (zero-copy) operation so shouldn't take
much time
tab = pa.Table.from_batches([batch])
field = pa.field(name, arr.type)
new_tab = tab.append_column(field, arr)
return new_tab.to_batches()[0]
in_arrow_path = 'example_input'
out_arrow_path = 'example_output'
# Create the input dataset
data_dict1 = {'partition': [1, 1, 2, 2],
'a': [1, 2, 3, 4],
'b': [2, 4, 6, 8],
'c': [10, 11, 12, 13]}
data_dict2 = {'partition': [1, 1, 2, 2],
'a': [5, 6, 7, 8],
'b': [10, 12, 14, 16],
'c': [20, 21, 22, 23]}
table1 = pa.Table.from_pydict(data_dict1)
table2 = pa.Table.from_pydict(data_dict2)
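# Write both tables into the same hive-partitioned directory, so each
# partition ends up containing two parquet files (the input to consolidate)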
ds.write_dataset(table1, in_arrow_path, format='parquet',
partitioning=['partition'],
partitioning_flavor='hive',
existing_data_behavior='delete_matching',
basename_template=f'{uuid4()}-{{i}}.parquet')
ds.write_dataset(table2, in_arrow_path, format='parquet',
partitioning=['partition'],
partitioning_flavor='hive',
existing_data_behavior='overwrite_or_ignore',
basename_template=f'{uuid4()}-{{i}}.parquet')
print('\n'.join(glob(f'{in_arrow_path}/**/*')))
dataset = ds.dataset(in_arrow_path, partitioning='hive')
print(dataset.to_table().to_pandas())
shutil.rmtree(out_arrow_path, ignore_errors=True)
# Re-save the input dataset adding a new column ("consolidation")
scanner = dataset.scanner()
## Transformer
class ScannerTransformer(ABC):
"""
Holds the logic to wrap a scanner to apply a transformation to each
batch
and the logic to compute the modified schema. The schema is needed by
ds.write_dataset when receiving an iterator of batches instead of a
scanner
"""
def __init__(self,
scanner: ds.Scanner,
input_cols: list[str],
output_col: str,
output_col_type: str,
transform_df_args: Optional[dict] = None,
) -> None:
self.scanner = scanner
self.input_cols = input_cols
self.output_col = output_col
self.output_col_type = output_col_type
        self.transform_df_args = transform_df_args if transform_df_args else {}
    def transform_scanner(self) -> Iterator[pa.RecordBatch]:
"""
Yields a transformed pa.RecordBatch at each iteration
"""
for batch in self.scanner.to_batches():
new_batch = self.transform_batch(batch)
yield new_batch
def transform_schema(self) -> pa.Schema:
new_field = pa.field(self.output_col, self.output_col_type)
schema = self.scanner.projected_schema.append(new_field)
return schema
@property
def schema(self) -> pa.Schema:
if not hasattr(self, '_schema'):
self._schema = self.transform_schema()
return self._schema
def transform_batch(self, batch: pa.RecordBatch) -> pa.RecordBatch:
table = pa.Table.from_batches([batch])
table = table.select(self.input_cols)
df = table.to_pandas()
out_series = (self.transform_df(df, **self.transform_df_args)
.astype(self.output_col_type))
out_array = pa.Array.from_pandas(out_series)
new_batch = append_column_to_batch(batch, out_array,
self.output_col)
return new_batch
@abstractmethod
def transform_df(self, df: pd.DataFrame, **kwargs) -> pd.Series:
...
class AppendAggCol(ScannerTransformer):
def transform_df(self, df: pd.DataFrame, agg_func: str = 'mean',
) -> pd.Series:
out_series = df.agg(agg_func, axis=1)
return out_series
transformer = AppendAggCol(scanner, input_cols=['a', 'b'],
output_col='out', output_col_type='float64',
transform_df_args={'agg_func': 'mean'})
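# NOTE: transform_scanner() returns a generator of RecordBatches, not a
# ds.Scanner, which is why write_dataset below needs the explicit schema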
wrapped_scanner = transformer.transform_scanner()
print('-----------------------------')
ds.write_dataset(wrapped_scanner, out_arrow_path,
schema=transformer.schema,
format='parquet',
partitioning=['partition'],
partitioning_flavor='hive',
existing_data_behavior='overwrite_or_ignore',
basename_template=f'{uuid4()}-{{i}}.parquet',
file_visitor=file_visitor)
print('\n'.join(glob(f'{out_arrow_path}/**/*')))
dataset = ds.dataset(out_arrow_path, partitioning='hive')
print(dataset.to_table().to_pandas())
```