This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new a63ead769a GH-34884: [Python] Support pickling pyarrow.dataset PartitioningFactory objects (#36550)
a63ead769a is described below
commit a63ead769ae5c7e4f2a32f528bcb39c5aae2ab49
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Mon Jul 10 08:32:40 2023 +0200
GH-34884: [Python] Support pickling pyarrow.dataset PartitioningFactory objects (#36550)
### Rationale for this change
https://github.com/apache/arrow/pull/36462 already added support for
pickling Partitioning objects, but not yet the PartitioningFactory objects.
The problem for PartitioningFactory is that we currently don't expose the
full class hierarchy in Python, only the base class PartitioningFactory.
We also don't expose a way to create those factory objects directly,
except through the `discover` methods of the Partitioning classes.
I think it would be nice to keep this minimal binding, but that means that if
we want to make them serializable with pickle (and we don't want to add custom
serialization code on the C++ side), we need another way to do that.
In this PR, I went for the route of storing the constructor (the `discover`
static method) and the arguments that were passed to it on the factory
object, so that `__reduce__` can return this info for pickling. Not
the nicest code, but the simplest solution I could think of.
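A minimal plain-Python sketch of this pattern (the `Factory` class and
`make_factory` helper below are illustrative stand-ins, not the actual
Cython implementation):
```python
import pickle


def make_factory(*args):
    # Module-level constructor stand-in (playing the role of the
    # `discover` static method), so pickle can reference it by name.
    return Factory(make_factory, args)


class Factory:
    def __init__(self, constructor, options):
        self.constructor = constructor  # callable that recreates the object
        self.options = options          # positional args for that callable

    def __reduce__(self):
        # Unpickling calls constructor(*options) to rebuild an equivalent object.
        return self.constructor, self.options


factory = make_factory("group", "key")
restored = pickle.loads(pickle.dumps(factory))
assert restored.options == ("group", "key")
```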
### Are these changes tested?
Yes
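The new parametrized tests run each discovery scenario both directly and
through a pickle round trip; with this patch applied, that round trip boils
down to:
```python
import pickle
import pyarrow.dataset as ds

factory = ds.DirectoryPartitioning.discover(['group', 'key'])
restored = pickle.loads(pickle.dumps(factory))
assert isinstance(restored, ds.PartitioningFactory)
assert restored.type_name == factory.type_name
```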
* Closes: #34884
Authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
---
python/pyarrow/_dataset.pxd | 5 ++-
python/pyarrow/_dataset.pyx | 39 +++++++++++++++++++++---
python/pyarrow/_dataset_parquet.pyx | 2 +-
python/pyarrow/tests/test_dataset.py | 59 +++++++++++++++++++++++++-----------
4 files changed, 80 insertions(+), 25 deletions(-)
diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd
index d626b42e23..210e555800 100644
--- a/python/pyarrow/_dataset.pxd
+++ b/python/pyarrow/_dataset.pxd
@@ -160,11 +160,14 @@ cdef class PartitioningFactory(_Weakrefable):
cdef:
shared_ptr[CPartitioningFactory] wrapped
CPartitioningFactory* factory
+ object constructor
+ object options
cdef init(self, const shared_ptr[CPartitioningFactory]& sp)
@staticmethod
- cdef wrap(const shared_ptr[CPartitioningFactory]& sp)
+ cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
+ object constructor, object options)
cdef inline shared_ptr[CPartitioningFactory] unwrap(self)
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 2ab8ffb798..c5f0a663a8 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -2374,16 +2374,22 @@ cdef class PartitioningFactory(_Weakrefable):
self.factory = sp.get()
@staticmethod
- cdef wrap(const shared_ptr[CPartitioningFactory]& sp):
+ cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
+ object constructor, object options):
cdef PartitioningFactory self = PartitioningFactory.__new__(
PartitioningFactory
)
self.init(sp)
+ self.constructor = constructor
+ self.options = options
return self
cdef inline shared_ptr[CPartitioningFactory] unwrap(self):
return self.wrapped
+ def __reduce__(self):
+ return self.constructor, self.options
+
@property
def type_name(self):
return frombytes(self.factory.type_name())
@@ -2454,6 +2460,10 @@ cdef class KeyValuePartitioning(Partitioning):
return res
+def _constructor_directory_partitioning_factory(*args):
+ return DirectoryPartitioning.discover(*args)
+
+
cdef class DirectoryPartitioning(KeyValuePartitioning):
"""
A Partitioning based on a specified Schema.
@@ -2571,7 +2581,15 @@ cdef class DirectoryPartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)
return PartitioningFactory.wrap(
- CDirectoryPartitioning.MakeFactory(c_field_names, c_options))
+ CDirectoryPartitioning.MakeFactory(c_field_names, c_options),
+ _constructor_directory_partitioning_factory,
+ (field_names, infer_dictionary, max_partition_dictionary_size,
+ schema, segment_encoding)
+ )
+
+
+def _constructor_hive_partitioning_factory(*args):
+ return HivePartitioning.discover(*args)
cdef class HivePartitioning(KeyValuePartitioning):
@@ -2714,7 +2732,15 @@ cdef class HivePartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)
return PartitioningFactory.wrap(
- CHivePartitioning.MakeFactory(c_options))
+ CHivePartitioning.MakeFactory(c_options),
+ _constructor_hive_partitioning_factory,
+ (infer_dictionary, max_partition_dictionary_size, null_fallback,
+ schema, segment_encoding),
+ )
+
+
+def _constructor_filename_partitioning_factory(*args):
+ return FilenamePartitioning.discover(*args)
cdef class FilenamePartitioning(KeyValuePartitioning):
@@ -2823,7 +2849,10 @@ cdef class FilenamePartitioning(KeyValuePartitioning):
c_options.segment_encoding = _get_segment_encoding(segment_encoding)
return PartitioningFactory.wrap(
- CFilenamePartitioning.MakeFactory(c_field_names, c_options))
+ CFilenamePartitioning.MakeFactory(c_field_names, c_options),
+ _constructor_filename_partitioning_factory,
+ (field_names, infer_dictionary, schema, segment_encoding)
+ )
cdef class DatasetFactory(_Weakrefable):
@@ -2988,7 +3017,7 @@ cdef class FileSystemFactoryOptions(_Weakrefable):
c_factory = self.options.partitioning.factory()
if c_factory.get() == nullptr:
return None
- return PartitioningFactory.wrap(c_factory)
+ return PartitioningFactory.wrap(c_factory, None, None)
@partitioning_factory.setter
def partitioning_factory(self, PartitioningFactory value):
diff --git a/python/pyarrow/_dataset_parquet.pyx b/python/pyarrow/_dataset_parquet.pyx
index bc4786b9cd..4ad0caec30 100644
--- a/python/pyarrow/_dataset_parquet.pyx
+++ b/python/pyarrow/_dataset_parquet.pyx
@@ -811,7 +811,7 @@ cdef class ParquetFactoryOptions(_Weakrefable):
c_factory = self.options.partitioning.factory()
if c_factory.get() == nullptr:
return None
- return PartitioningFactory.wrap(c_factory)
+ return PartitioningFactory.wrap(c_factory, None, None)
@partitioning_factory.setter
def partitioning_factory(self, PartitioningFactory value):
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 814454861e..2f9b6a0922 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1642,12 +1642,15 @@ def test_fragments_repr(tempdir, dataset):
@pytest.mark.parquet
-def test_partitioning_factory(mockfs):
[email protected](
+ "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory(mockfs, pickled):
paths_or_selector = fs.FileSelector('subdir', recursive=True)
format = ds.ParquetFileFormat()
options = ds.FileSystemFactoryOptions('subdir')
partitioning_factory = ds.DirectoryPartitioning.discover(['group', 'key'])
+ partitioning_factory = pickled(partitioning_factory)
assert isinstance(partitioning_factory, ds.PartitioningFactory)
options.partitioning_factory = partitioning_factory
@@ -1673,13 +1676,16 @@ def test_partitioning_factory(mockfs):
@pytest.mark.parquet
@pytest.mark.parametrize('infer_dictionary', [False, True])
-def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
[email protected](
+ "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_dictionary(mockfs, infer_dictionary, pickled):
paths_or_selector = fs.FileSelector('subdir', recursive=True)
format = ds.ParquetFileFormat()
options = ds.FileSystemFactoryOptions('subdir')
- options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ partitioning_factory = ds.DirectoryPartitioning.discover(
['group', 'key'], infer_dictionary=infer_dictionary)
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(
mockfs, paths_or_selector, format, options)
@@ -1703,7 +1709,9 @@ def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
assert inferred_schema.field('key').type == pa.string()
-def test_partitioning_factory_segment_encoding():
[email protected](
+ "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_segment_encoding(pickled):
mockfs = fs._MockFileSystem()
format = ds.IpcFileFormat()
schema = pa.schema([("i64", pa.int64())])
@@ -1726,8 +1734,9 @@ def test_partitioning_factory_segment_encoding():
# Directory
selector = fs.FileSelector("directory", recursive=True)
options = ds.FileSystemFactoryOptions("directory")
- options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ partitioning_factory = ds.DirectoryPartitioning.discover(
schema=partition_schema)
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
@@ -1736,24 +1745,27 @@ def test_partitioning_factory_segment_encoding():
})
assert actual[0][0].as_py() == 1620086400
- options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ partitioning_factory = ds.DirectoryPartitioning.discover(
["date", "string"], segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))
- options.partitioning = ds.DirectoryPartitioning(
+ partitioning = ds.DirectoryPartitioning(
string_partition_schema, segment_encoding="none")
+ options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))
- options.partitioning_factory = ds.DirectoryPartitioning.discover(
+ partitioning_factory = ds.DirectoryPartitioning.discover(
schema=partition_schema, segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):
@@ -1762,8 +1774,9 @@ def test_partitioning_factory_segment_encoding():
# Hive
selector = fs.FileSelector("hive", recursive=True)
options = ds.FileSystemFactoryOptions("hive")
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema)
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
@@ -1772,8 +1785,9 @@ def test_partitioning_factory_segment_encoding():
})
assert actual[0][0].as_py() == 1620086400
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
@@ -1788,15 +1802,18 @@ def test_partitioning_factory_segment_encoding():
(ds.field("date") == "2021-05-04 00%3A00%3A00") &
(ds.field("string") == "%24"))
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema, segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):
inferred_schema = factory.inspect()
-def test_partitioning_factory_hive_segment_encoding_key_encoded():
[email protected](
+ "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_hive_segment_encoding_key_encoded(pickled):
mockfs = fs._MockFileSystem()
format = ds.IpcFileFormat()
schema = pa.schema([("i64", pa.int64())])
@@ -1825,8 +1842,9 @@ def test_partitioning_factory_hive_segment_encoding_key_encoded():
# Hive
selector = fs.FileSelector("hive", recursive=True)
options = ds.FileSystemFactoryOptions("hive")
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema)
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
inferred_schema = factory.inspect()
assert inferred_schema == full_schema
@@ -1835,40 +1853,45 @@ def test_partitioning_factory_hive_segment_encoding_key_encoded():
})
assert actual[0][0].as_py() == 1620086400
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="uri")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test'; date") == "2021-05-04 00:00:00") &
(ds.field("test';[ string'") == "$"))
- options.partitioning = ds.HivePartitioning(
+ partitioning = ds.HivePartitioning(
string_partition_schema, segment_encoding="uri")
+ options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test'; date") == "2021-05-04 00:00:00") &
(ds.field("test';[ string'") == "$"))
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
(ds.field("test%27%3B%5B%20string%27") == "%24"))
- options.partitioning = ds.HivePartitioning(
+ partitioning = ds.HivePartitioning(
string_partition_schema_en, segment_encoding="none")
+ options.partitioning = pickled(partitioning)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
fragments = list(factory.finish().get_fragments())
assert fragments[0].partition_expression.equals(
(ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
(ds.field("test%27%3B%5B%20string%27") == "%24"))
- options.partitioning_factory = ds.HivePartitioning.discover(
+ partitioning_factory = ds.HivePartitioning.discover(
schema=partition_schema_en, segment_encoding="none")
+ options.partitioning_factory = pickled(partitioning_factory)
factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
with pytest.raises(pa.ArrowInvalid,
match="Could not cast segments for partition field"):