This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 1016b197 Implement __getstate__ and __setstate__ so that FileIO instances can be pickled (#543)
1016b197 is described below

commit 1016b197096472599c42228205f8a23f9ed28765
Author: Amogh Jahagirdar <am...@tabular.io>
AuthorDate: Sun Apr 7 15:14:46 2024 -0600

    Implement __getstate__ and __setstate__ so that FileIO instances can be pickled (#543)
---
 pyiceberg/io/fsspec.py  | 12 ++++++++++++
 pyiceberg/io/pyarrow.py | 12 ++++++++++++
 tests/io/test_fsspec.py | 31 +++++++++++++++++++++++++++++++
 tests/io/test_io.py     | 20 ++++++++++++++++++++
 4 files changed, 75 insertions(+)

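For context on the pattern used in the patch below: an instance attribute that holds an lru_cache-wrapped callable cannot be pickled, so __getstate__ drops it from the serialized state and __setstate__ rebuilds it from the underlying method. A minimal, self-contained sketch of that pattern (the CachedResolver class is illustrative only, not part of this commit):

# Minimal sketch (not from the patch) of the pickling pattern used below:
# an lru_cache-wrapped bound method stored on the instance is not picklable,
# so it is dropped in __getstate__ and rebuilt in __setstate__.
import pickle
from copy import copy
from functools import lru_cache
from typing import Any, Dict


class CachedResolver:
    def __init__(self) -> None:
        self.resolve = lru_cache(self._resolve)

    def _resolve(self, scheme: str) -> str:
        return f"filesystem-for-{scheme}"

    def __getstate__(self) -> Dict[str, Any]:
        state = copy(self.__dict__)
        state["resolve"] = None  # remove the unpicklable cached callable
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__ = state
        self.resolve = lru_cache(self._resolve)  # rebuild the cache


resolver = pickle.loads(pickle.dumps(CachedResolver()))
assert resolver.resolve("s3") == "filesystem-for-s3"
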
diff --git a/pyiceberg/io/fsspec.py b/pyiceberg/io/fsspec.py
index 957cac66..ee97829c 100644
--- a/pyiceberg/io/fsspec.py
+++ b/pyiceberg/io/fsspec.py
@@ -20,6 +20,7 @@ import errno
 import json
 import logging
 import os
+from copy import copy
 from functools import lru_cache, partial
 from typing import (
     Any,
@@ -338,3 +339,14 @@ class FsspecFileIO(FileIO):
         if scheme not in self._scheme_to_fs:
             raise ValueError(f"No registered filesystem for scheme: {scheme}")
         return self._scheme_to_fs[scheme](self.properties)
+
+    def __getstate__(self) -> Dict[str, Any]:
+        """Create a dictionary of the FsSpecFileIO fields used when 
pickling."""
+        fileio_copy = copy(self.__dict__)
+        fileio_copy["get_fs"] = None
+        return fileio_copy
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        """Deserialize the state into a FsSpecFileIO instance."""
+        self.__dict__ = state
+        self.get_fs = lru_cache(self._get_fs)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 06d03e21..1848fba7 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -33,6 +33,7 @@ import os
 import re
 from abc import ABC, abstractmethod
 from concurrent.futures import Future
+from copy import copy
 from dataclasses import dataclass
 from enum import Enum
 from functools import lru_cache, singledispatch
@@ -456,6 +457,17 @@ class PyArrowFileIO(FileIO):
                 raise PermissionError(f"Cannot delete file, access denied: {location}") from e
             raise  # pragma: no cover - If some other kind of OSError, raise the raw error
 
+    def __getstate__(self) -> Dict[str, Any]:
+        """Create a dictionary of the PyArrowFileIO fields used when 
pickling."""
+        fileio_copy = copy(self.__dict__)
+        fileio_copy["fs_by_scheme"] = None
+        return fileio_copy
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        """Deserialize the state into a PyArrowFileIO instance."""
+        self.__dict__ = state
+        self.fs_by_scheme = lru_cache(self._initialize_fs)
+
 
 def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema:
     return visit(schema, _ConvertToArrowSchema(metadata))
diff --git a/tests/io/test_fsspec.py b/tests/io/test_fsspec.py
index 9f044454..013fc5d8 100644
--- a/tests/io/test_fsspec.py
+++ b/tests/io/test_fsspec.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import os
+import pickle
 import tempfile
 import uuid
 
@@ -229,6 +230,11 @@ def test_writing_avro_file(generated_manifest_entry_file: str, fsspec_fileio: Fs
     fsspec_fileio.delete(f"s3://warehouse/{filename}")
 
 
+@pytest.mark.s3
+def test_fsspec_pickle_round_trip_s3(fsspec_fileio: FsspecFileIO) -> None:
+    _test_fsspec_pickle_round_trip(fsspec_fileio, "s3://warehouse/foo.txt")
+
+
 @pytest.mark.adlfs
 def test_fsspec_new_input_file_adlfs(adlfs_fsspec_fileio: FsspecFileIO) -> None:
     """Test creating a new input file from an fsspec file-io"""
@@ -410,6 +416,11 @@ def test_writing_avro_file_adlfs(generated_manifest_entry_file: str, adlfs_fsspe
     adlfs_fsspec_fileio.delete(f"abfss://tests/{filename}")
 
 
+@pytest.mark.adlfs
+def test_fsspec_pickle_round_trip_aldfs(adlfs_fsspec_fileio: FsspecFileIO) -> 
None:
+    _test_fsspec_pickle_round_trip(adlfs_fsspec_fileio, 
"abfss://tests/foo.txt")
+
+
 @pytest.mark.gcs
 def test_fsspec_new_input_file_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
     """Test creating a new input file from a fsspec file-io"""
@@ -586,6 +597,26 @@ def test_writing_avro_file_gcs(generated_manifest_entry_file: str, fsspec_fileio
     fsspec_fileio_gcs.delete(f"gs://warehouse/{filename}")
 
 
+@pytest.mark.gcs
+def test_fsspec_pickle_roundtrip_gcs(fsspec_fileio_gcs: FsspecFileIO) -> None:
+    _test_fsspec_pickle_round_trip(fsspec_fileio_gcs, "gs://warehouse/foo.txt")
+
+
+def _test_fsspec_pickle_round_trip(fsspec_fileio: FsspecFileIO, location: str) -> None:
+    serialized_file_io = pickle.dumps(fsspec_fileio)
+    deserialized_file_io = pickle.loads(serialized_file_io)
+    output_file = deserialized_file_io.new_output(location)
+    with output_file.create() as f:
+        f.write(b"foo")
+
+    input_file = deserialized_file_io.new_input(location)
+    with input_file.open() as f:
+        data = f.read()
+        assert data == b"foo"
+        assert len(input_file) == 3
+    deserialized_file_io.delete(location)
+
+
 TEST_URI = "https://iceberg-test-signer"
 
 
diff --git a/tests/io/test_io.py b/tests/io/test_io.py
index d27274e1..b273288b 100644
--- a/tests/io/test_io.py
+++ b/tests/io/test_io.py
@@ -16,6 +16,7 @@
 # under the License.
 
 import os
+import pickle
 import tempfile
 
 import pytest
@@ -71,6 +72,25 @@ def test_custom_local_output_file() -> None:
         assert len(output_file) == 3
 
 
+def test_pickled_pyarrow_round_trip() -> None:
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        file_location = os.path.join(tmpdirname, "foo.txt")
+        file_io = PyArrowFileIO()
+        serialized_file_io = pickle.dumps(file_io)
+        deserialized_file_io = pickle.loads(serialized_file_io)
+        absolute_file_location = os.path.abspath(file_location)
+        output_file = deserialized_file_io.new_output(location=f"{absolute_file_location}")
+        with output_file.create() as f:
+            f.write(b"foo")
+
+        input_file = deserialized_file_io.new_input(location=f"{absolute_file_location}")
+        f = input_file.open()
+        data = f.read()
+        assert data == b"foo"
+        assert len(input_file) == 3
+        deserialized_file_io.delete(location=f"{absolute_file_location}")
+
+
 def test_custom_local_output_file_with_overwrite() -> None:
     """Test initializing an OutputFile implementation to overwrite a local 
file"""
     with tempfile.TemporaryDirectory() as tmpdirname:

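With these __getstate__/__setstate__ implementations in place, a FileIO can survive the pickling performed by process pools and similar task schedulers. A rough usage sketch, not part of the commit or its tests; the local path and pool size are illustrative:

# Hedged usage sketch: a PyArrowFileIO is pickled when handed to a worker
# process, so I/O work can now be fanned out across processes.
import os
import tempfile
from concurrent.futures import ProcessPoolExecutor

from pyiceberg.io.pyarrow import PyArrowFileIO


def read_length(file_io: PyArrowFileIO, location: str) -> int:
    # Runs in a worker process; file_io arrives there via pickle.
    return len(file_io.new_input(location))


if __name__ == "__main__":
    location = os.path.join(tempfile.mkdtemp(), "foo.txt")
    file_io = PyArrowFileIO()
    with file_io.new_output(location).create() as f:
        f.write(b"foo")
    with ProcessPoolExecutor(max_workers=1) as pool:
        print(pool.submit(read_length, file_io, location).result())  # prints 3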