This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1b63af85de7b [SPARK-55851][PYTHON] Clarify types of datasource
partition and read
1b63af85de7b is described below
commit 1b63af85de7b593a6c4217f77208371566cbea19
Author: Tian Gao <[email protected]>
AuthorDate: Thu Mar 12 07:02:56 2026 +0900
[SPARK-55851][PYTHON] Clarify types of datasource partition and read
### What changes were proposed in this pull request?
Clarify and unify the types of the datasource partition/read functions.
* Correct type annotation
* Change the default behavior of `partition` to make it match the
documentation.
### Why are the changes needed?
Our current type hints for `partition` and `read` are wrong. We do accept
`None` as a partition in our code, but we did not mention it. This will confuse
users because this interface is documented and user-facing.
We also say the default behavior for `partition` is to return a list with
`None` - but we didn't do that. Instead, we used a very convoluted approach:
raise an exception, catch it in the worker, and convert it to `[None]`.
That is unnecessarily complex.
This change should not affect users if they already override their
`partition` and `read` methods (unless they override them with a function that
raises `NotImplementedError`).
Overall this gives us a more consistent interface with correct typing.
### Does this PR introduce _any_ user-facing change?
Basically no, unless a user does something unusual.
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54635 from gaogaotiantian/fix-datasource-read-type.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/datasource.py | 11 ++---
python/pyspark/sql/tests/test_python_datasource.py | 6 +--
python/pyspark/sql/worker/plan_data_source_read.py | 47 ++++++++++------------
.../execution/python/PythonDataSourceSuite.scala | 2 +-
4 files changed, 30 insertions(+), 36 deletions(-)
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index bb73a7a9206b..554715034938 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -588,8 +588,8 @@ class DataSourceReader(ABC):
partition value to read the data.
This method is called once during query planning. By default, it
returns a
- single partition with the value ``None``. Subclasses can override this
method
- to return multiple partitions.
+ single partition with the value `InputPartition(None)`. Subclasses can
override
+ this method to return multiple partitions.
It's recommended to override this method for better performance when
reading
large datasets.
@@ -626,10 +626,7 @@ class DataSourceReader(ABC):
>>> def partitions(self):
... return [RangeInputPartition(1, 3), RangeInputPartition(5, 10)]
"""
- raise PySparkNotImplementedError(
- errorClass="NOT_IMPLEMENTED",
- messageParameters={"feature": "partitions"},
- )
+ return [InputPartition(None)]
@abstractmethod
def read(self, partition: InputPartition) -> Union[Iterator[Tuple],
Iterator["RecordBatch"]]:
@@ -643,7 +640,7 @@ class DataSourceReader(ABC):
Parameters
----------
- partition : object
+ partition : InputPartition
The partition to read. It must be one of the partition values
returned by
:meth:`DataSourceReader.partitions`.
diff --git a/python/pyspark/sql/tests/test_python_datasource.py
b/python/pyspark/sql/tests/test_python_datasource.py
index 9d90082c654d..e8776c288720 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -162,7 +162,7 @@ class BasePythonDataSourceTestsMixin:
if partition_func is not None:
return partition_func()
else:
- raise NotImplementedError
+ return [InputPartition(None)]
def read(self, partition):
return read_func(self.schema, partition)
@@ -1040,7 +1040,7 @@ class BasePythonDataSourceTestsMixin:
{"class_name": "TestJsonReader", "func_name":
"partitions"},
),
(
- "TestJsonReader.read: None",
+ "TestJsonReader.read: InputPartition(value=None)",
{"class_name": "TestJsonReader", "func_name":
"read"},
),
]
@@ -1151,7 +1151,7 @@ class BasePythonDataSourceTestsMixin:
{"class_name": "TestJsonReader", "func_name":
"partitions"},
),
(
- "TestJsonReader.read: None",
+ "TestJsonReader.read: InputPartition(value=None)",
{"class_name": "TestJsonReader", "func_name":
"read"},
),
]
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py
b/python/pyspark/sql/worker/plan_data_source_read.py
index cded42031b0c..d736df6084c6 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -210,12 +210,12 @@ def write_read_func_and_partitions(
# Deserialize the partition value.
partition = pickleSer.loads(partition_bytes)
- assert partition is None or isinstance(partition, InputPartition), (
+ assert isinstance(partition, InputPartition), (
"Expected the partition value to be of type 'InputPartition', "
f"but found '{type(partition).__name__}'."
)
- output_iter = reader.read(partition) # type: ignore[arg-type]
+ output_iter = reader.read(partition)
# Validate the output iterator.
if not isinstance(output_iter, Iterator):
@@ -240,29 +240,26 @@ def write_read_func_and_partitions(
if not is_streaming:
# The partitioning of python batch source read is determined before
query execution.
- try:
- partitions = reader.partitions() # type: ignore[call-arg]
- if not isinstance(partitions, list):
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "'partitions' to return a list",
- "actual": f"'{type(partitions).__name__}'",
- },
- )
- if not all(isinstance(p, InputPartition) for p in partitions):
- partition_types = ", ".join([f"'{type(p).__name__}'" for p in
partitions])
- raise PySparkRuntimeError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "elements in 'partitions' to be of type
'InputPartition'",
- "actual": partition_types,
- },
- )
- if len(partitions) == 0:
- partitions = [None] # type: ignore[list-item]
- except NotImplementedError:
- partitions = [None] # type: ignore[list-item]
+ partitions = reader.partitions() # type: ignore[call-arg]
+ if not isinstance(partitions, list):
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "'partitions' to return a list",
+ "actual": f"'{type(partitions).__name__}'",
+ },
+ )
+ if not all(isinstance(p, InputPartition) for p in partitions):
+ partition_types = ", ".join([f"'{type(p).__name__}'" for p in
partitions])
+ raise PySparkRuntimeError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "elements in 'partitions' to be of type
'InputPartition'",
+ "actual": partition_types,
+ },
+ )
+ if len(partitions) == 0:
+ partitions = [InputPartition(None)]
# Return the serialized partition values.
write_int(len(partitions), outfile)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bac9849381a3..a712607f18e4 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -533,7 +533,7 @@ class PythonDataSourceSuite extends
PythonDataSourceSuiteBase {
| return []
|
| def read(self, partition):
- | if partition is None:
+ | if partition.value is None:
| yield ("success", )
| else:
| yield ("failed", )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]