This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ecf179c3485b [SPARK-54337][PS] Add support for PyCapsule to Pyspark
ecf179c3485b is described below
commit ecf179c3485ba8bac72afd9105892d9798d23f8f
Author: Devin Petersohn <[email protected]>
AuthorDate: Mon Jan 12 09:07:51 2026 +0800
[SPARK-54337][PS] Add support for PyCapsule to Pyspark
### What changes were proposed in this pull request?
Add support for the Arrow PyCapsule protocol (`__arrow_c_stream__`) to facilitate
interchange between Spark and other Python libraries. Here is a demo of what this
enables with Polars and DuckDB:
```
Welcome to
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/__ / .__/\_,_/_/ /_/\_\ version 4.2.0-SNAPSHOT
/_/
Using Python version 3.11.5 (main, Sep 11 2023 08:31:25)
Spark context Web UI available at http://192.168.86.83:4040
Spark context available as 'sc' (master = local[*], app id =
local-1765227291836).
SparkSession available as 'spark'.
In [1]: import pyspark.pandas as ps
...: import pandas as pd
...: import numpy as np
...: import polars as pl
...:
...: pdf = pd.DataFrame(
...: {"A": [True, False], "B": [1, np.nan], "C": [True, None], "D":
[None, np.nan]}
...: )
...: psdf = ps.from_pandas(pdf)
...: polars_df = pl.DataFrame(psdf)
/Users/dpetersohn/software_sources/spark/python/pyspark/pandas/__init__.py:43:
UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is
required to set this environment variable to '1' in both driver and executor
sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it
does not work if there is a Spark context already launched.
warnings.warn(
[Stage 0:> (0 + 1)
/ 1]
In [2]: polars_df
Out[2]:
shape: (2, 5)
┌───────────────────┬───────┬──────┬──────┬──────┐
│ __index_level_0__ ┆ A ┆ B ┆ C ┆ D │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ bool ┆ f64 ┆ bool ┆ f64 │
╞═══════════════════╪═══════╪══════╪══════╪══════╡
│ 0 ┆ true ┆ 1.0 ┆ true ┆ null │
│ 1 ┆ false ┆ null ┆ null ┆ null │
└───────────────────┴───────┴──────┴──────┴──────┘
In [3]: import duckdb
In [4]: import pyarrow as pa
In [5]: stream = pa.RecordBatchReader.from_stream(psdf)
In [6]: duckdb.sql("SELECT count(*) AS total, avg(B) FROM stream WHERE B IS
NOT NULL").fetchall()
Out[6]: [(1, 1.0)]
```
Polars can now consume a full PySpark DataFrame (or a `pyspark.pandas`
DataFrame), and DuckDB can consume a stream built from the PySpark DataFrame.
Importantly, the `stream = pa.RecordBatchReader.from_stream(psdf)` line does not
trigger any computation; it simply creates a stream object that DuckDB consumes
incrementally when `fetchall` is called.
### Why are the changes needed?
Currently, PySpark (and, to a lesser degree, the pandas API on Spark) does not
integrate well with the broader Python ecosystem. The current best practice is
to go through pandas with `toPandas`, but that materializes all data on the
driver at once. This new API and protocol let data stream one Arrow batch at a
time, enabling libraries like DuckDB and Polars to consume the data as a stream.
### Does this PR introduce _any_ user-facing change?
Yes, new user-level API.
### How was this patch tested?
Locally
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #53391 from devin-petersohn/devin/pycapsule.
Lead-authored-by: Devin Petersohn <[email protected]>
Co-authored-by: Devin Petersohn <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/pandas/frame.py | 19 +++++
python/pyspark/sql/dataframe.py | 17 +++++
python/pyspark/sql/interchange.py | 89 ++++++++++++++++++++++
.../pyspark/sql/tests/arrow/test_arrow_c_stream.py | 64 ++++++++++++++++
5 files changed, 190 insertions(+)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 4e956314c3d8..0ff9d6634377 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -551,6 +551,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_job_cancellation",
"pyspark.sql.tests.arrow.test_arrow",
"pyspark.sql.tests.arrow.test_arrow_map",
+ "pyspark.sql.tests.arrow.test_arrow_c_stream",
"pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
"pyspark.sql.tests.arrow.test_arrow_grouped_map",
"pyspark.sql.tests.arrow.test_arrow_python_udf",
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index df68e31d4f33..23ac31c8ebfb 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -13823,6 +13823,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
# we always wraps the given type hints by a tuple to mimic the
variadic generic.
return create_tuple_for_frame_type(params)
def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
    """
    Export to a C PyCapsule stream object (Arrow PyCapsule interface).

    The stream is built from the internal Spark frame, so the index columns
    (for example ``__index_level_0__``) are included alongside the data
    columns. Batches are produced lazily: no Spark job is started until the
    consumer begins reading from the returned stream.

    Parameters
    ----------
    requested_schema : PyCapsule, optional
        The schema to attempt to use for the output stream. This is a
        best-effort request; it is forwarded unchanged to
        ``pyarrow.RecordBatchReader.__arrow_c_stream__``.

    Returns
    -------
    PyCapsule
        A PyCapsule named ``"arrow_array_stream"`` wrapping a C
        ``ArrowArrayStream``.
    """
    # Imported lazily so that pyarrow is only required when the protocol is used.
    from pyspark.sql.interchange import SparkArrowCStreamer

    return SparkArrowCStreamer(self._internal.to_internal_spark_frame).__arrow_c_stream__(
        requested_schema
    )
+
def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) ->
Any:
"""
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7b7547b68ff2..2ddfdda762d7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6982,6 +6982,23 @@ class DataFrame:
"""
...
def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
    """
    Export to a C PyCapsule stream object (Arrow PyCapsule interface).

    Batches are produced lazily: no Spark job is started until the consumer
    begins reading from the returned stream, and partitions are then fetched
    one at a time rather than collected on the driver all at once.

    Parameters
    ----------
    requested_schema : PyCapsule, optional
        The schema to attempt to use for the output stream. This is a
        best-effort request; it is forwarded unchanged to
        ``pyarrow.RecordBatchReader.__arrow_c_stream__``.

    Returns
    -------
    PyCapsule
        A PyCapsule named ``"arrow_array_stream"`` wrapping a C
        ``ArrowArrayStream``.
    """
    # Imported lazily so that pyarrow is only required when the protocol is used.
    from pyspark.sql.interchange import SparkArrowCStreamer

    return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema)
+
class DataFrameNaFunctions:
"""Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/interchange.py
b/python/pyspark/sql/interchange.py
new file mode 100644
index 000000000000..141d9f37148e
--- /dev/null
+++ b/python/pyspark/sql/interchange.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Iterator, Optional
+
+import pyarrow as pa
+
+import pyspark.sql
+from pyspark.sql.types import StructType, StructField, BinaryType
+from pyspark.sql.pandas.types import to_arrow_schema
+
+
def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
    """
    Lazily yield the contents of ``df`` as Arrow RecordBatches, partition by partition.

    Each partition is serialized on the executors into Arrow IPC stream bytes
    via ``mapInArrow`` and pulled back to the driver with ``toLocalIterator``,
    so the driver never holds all partitions at once.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        The DataFrame to stream.

    Yields
    ------
    pyarrow.RecordBatch
        The batches of ``df``, in partition order.
    """
    # mapInArrow must return data matching a Spark schema, so each serialized
    # RecordBatch travels back as a single non-null binary cell.
    binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)])

    def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
        """
        Serialize each incoming RecordBatch into a 1-row, 1-column binary batch.

        Spark hands a partition to this function as an iterator of RecordBatches
        (sized by ``spark.sql.execution.arrow.maxRecordsPerBatch``). Each one is
        written as a self-contained Arrow IPC stream and wrapped in a single
        binary cell, which satisfies the mapInArrow return signature.

        Raises
        ------
        ValueError
            If a serialized batch does not fit in a single ``pa.binary()`` value.
        """
        # ``pa.binary()`` uses signed int32 offsets, so one value is at most this many bytes.
        max_binary_len = 2**31 - 1
        for arrow_batch in batch_iter:
            sink = pa.BufferOutputStream()
            with pa.ipc.new_stream(sink, arrow_batch.schema) as writer:
                writer.write_batch(arrow_batch)
            # Read the buffer only after the writer is closed so the
            # end-of-stream marker has been written.
            buf = sink.getvalue()
            if len(buf) > max_binary_len:
                raise ValueError(
                    f"A serialized Arrow batch of {len(buf)} bytes exceeds the "
                    f"{max_binary_len}-byte limit of a single binary value. Lower "
                    "'spark.sql.execution.arrow.maxRecordsPerBatch' to produce "
                    "smaller batches."
                )
            # ``pa.array(...).buffers()`` is [validity, values]; the int32 values
            # [0, len(buf)] are reused as the offsets buffer of a 1-element binary
            # array whose data buffer is ``buf``. No validity bitmap: never null.
            offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1]
            storage_arr = pa.Array.from_buffers(
                type=pa.binary(), length=1, buffers=[None, offset_buf, buf]
            )
            yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"])

    byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)
    # Each row holds one IPC stream; fetch rows (and hence partitions) one by
    # one and re-emit the batches they contain.
    for row in byte_df.toLocalIterator():
        with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader:
            yield from reader
+
+
class SparkArrowCStreamer:
    """
    Implements the Arrow PyCapsule ``__arrow_c_stream__`` protocol for a Spark DataFrame.

    Consumers read the DataFrame one partition at a time, so all partitions
    are never materialized on the driver at once. Each call to
    :meth:`__arrow_c_stream__` builds a fresh stream, which means the
    DataFrame is re-executed every time a new stream is consumed.
    """

    def __init__(self, df: pyspark.sql.DataFrame) -> None:
        self._df = df
        # Arrow schema advertised to stream consumers, derived from the Spark schema.
        # NOTE(review): the batches produced via mapInArrow are assumed to match
        # this schema exactly (e.g. when large var types are enabled) — confirm.
        self._schema = to_arrow_schema(df.schema)

    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
        """
        Return an Arrow C stream PyCapsule over the DataFrame's partitions.

        No Spark job runs here: the batch generator is only advanced when the
        consumer reads from the stream.

        Parameters
        ----------
        requested_schema : PyCapsule, optional
            Best-effort schema request, forwarded to
            ``pyarrow.RecordBatchReader.__arrow_c_stream__``.

        Returns
        -------
        PyCapsule
            A PyCapsule named ``"arrow_array_stream"``.
        """
        reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
            self._schema, _get_arrow_array_partition_stream(self._df)
        )
        return reader.__arrow_c_stream__(requested_schema=requested_schema)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py
b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py
new file mode 100644
index 000000000000..9534db71bae6
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import ctypes
+import unittest
+import pyarrow as pa
+import pandas as pd
+import pyspark.pandas as ps
+
+
class TestSparkArrowCStreamer(unittest.TestCase):
    """Tests for the Arrow PyCapsule stream export of Spark DataFrames."""

    def test_spark_arrow_c_streamer_arrow_consumer(self):
        pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"])
        psdf = ps.from_pandas(pdf)

        capsule = psdf.__arrow_c_stream__()
        # PyCapsule_IsValid returns 1 when the object is a capsule with this exact name.
        self.assertEqual(
            ctypes.pythonapi.PyCapsule_IsValid(ctypes.py_object(capsule), b"arrow_array_stream"),
            1,
        )

        stream = pa.RecordBatchReader.from_stream(psdf)
        self.assertIsInstance(stream, pa.RecordBatchReader)
        result = pa.Table.from_batches(stream)
        # pandas-on-Spark exports its index column alongside the data columns.
        schema = pa.schema(
            [
                ("__index_level_0__", pa.int64(), False),
                ("id", pa.int64(), False),
                ("value", pa.string(), False),
            ]
        )
        expected = pa.Table.from_pandas(
            pd.DataFrame(
                [[0, 1, "a"], [1, 2, "b"], [2, 3, "c"], [3, 4, "d"]],
                columns=["__index_level_0__", "id", "value"],
            ),
            schema=schema,
        )
        self.assertEqual(result, expected)

    def test_spark_sql_dataframe_arrow_consumer(self):
        # Exercise pyspark.sql.DataFrame.__arrow_c_stream__ directly (no index column).
        pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"])
        sdf = ps.from_pandas(pdf).to_spark()

        stream = pa.RecordBatchReader.from_stream(sdf)
        self.assertIsInstance(stream, pa.RecordBatchReader)
        result = stream.read_all()
        self.assertEqual(result.column_names, ["id", "value"])
        self.assertEqual(
            result.to_pydict(),
            {"id": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]},
        )
+
+
if __name__ == "__main__":
    from pyspark.sql.tests.arrow.test_arrow_c_stream import *  # noqa: F401

    def _make_test_runner():
        """Return an XML test runner when xmlrunner is installed, else None."""
        try:
            import xmlrunner  # type: ignore

            return xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
        except ImportError:
            return None

    unittest.main(testRunner=_make_test_runner(), verbosity=2)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]