This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b63af85de7b [SPARK-55851][PYTHON] Clarify types of datasource 
partition and read
1b63af85de7b is described below

commit 1b63af85de7b593a6c4217f77208371566cbea19
Author: Tian Gao <[email protected]>
AuthorDate: Thu Mar 12 07:02:56 2026 +0900

    [SPARK-55851][PYTHON] Clarify types of datasource partition and read
    
    ### What changes were proposed in this pull request?
    
    Clarify and unify the types of datasource partition/read function.
    * Correct type annotation
    * Change the default behavior of `partition` to make it match the 
documentation.
    
    ### Why are the changes needed?
    
    Our current type hints for `partitions` and `read` are wrong. We do accept
    `None` as a partition in our code, but we did not mention it. This will
    confuse users, as this interface is documented and user-facing.
    
    We also say the default behavior of `partitions` is to return a list
    containing `None` - but we didn't actually do that. Instead, we used a very
    convoluted approach - raising an exception, catching it in the worker, and
    converting it to `[None]`. That's completely unnecessary.
    
    This change should not affect users if they have already overridden their
    `partitions` and `read` methods (unless they override them with a function
    that raises `NotImplementedError`).
    
    Overall this gives us a more consistent interface with correct typing.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Basically no, unless the user does something unusual.
    
    ### How was this patch tested?
    
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54635 from gaogaotiantian/fix-datasource-read-type.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/datasource.py                   | 11 ++---
 python/pyspark/sql/tests/test_python_datasource.py |  6 +--
 python/pyspark/sql/worker/plan_data_source_read.py | 47 ++++++++++------------
 .../execution/python/PythonDataSourceSuite.scala   |  2 +-
 4 files changed, 30 insertions(+), 36 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index bb73a7a9206b..554715034938 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -588,8 +588,8 @@ class DataSourceReader(ABC):
         partition value to read the data.
 
         This method is called once during query planning. By default, it 
returns a
-        single partition with the value ``None``. Subclasses can override this 
method
-        to return multiple partitions.
+        single partition with the value `InputPartition(None)`. Subclasses can 
override
+        this method to return multiple partitions.
 
         It's recommended to override this method for better performance when 
reading
         large datasets.
@@ -626,10 +626,7 @@ class DataSourceReader(ABC):
         >>> def partitions(self):
         ...     return [RangeInputPartition(1, 3), RangeInputPartition(5, 10)]
         """
-        raise PySparkNotImplementedError(
-            errorClass="NOT_IMPLEMENTED",
-            messageParameters={"feature": "partitions"},
-        )
+        return [InputPartition(None)]
 
     @abstractmethod
     def read(self, partition: InputPartition) -> Union[Iterator[Tuple], 
Iterator["RecordBatch"]]:
@@ -643,7 +640,7 @@ class DataSourceReader(ABC):
 
         Parameters
         ----------
-        partition : object
+        partition : InputPartition
             The partition to read. It must be one of the partition values 
returned by
             :meth:`DataSourceReader.partitions`.
 
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index 9d90082c654d..e8776c288720 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -162,7 +162,7 @@ class BasePythonDataSourceTestsMixin:
                 if partition_func is not None:
                     return partition_func()
                 else:
-                    raise NotImplementedError
+                    return [InputPartition(None)]
 
             def read(self, partition):
                 return read_func(self.schema, partition)
@@ -1040,7 +1040,7 @@ class BasePythonDataSourceTestsMixin:
                             {"class_name": "TestJsonReader", "func_name": 
"partitions"},
                         ),
                         (
-                            "TestJsonReader.read: None",
+                            "TestJsonReader.read: InputPartition(value=None)",
                             {"class_name": "TestJsonReader", "func_name": 
"read"},
                         ),
                     ]
@@ -1151,7 +1151,7 @@ class BasePythonDataSourceTestsMixin:
                             {"class_name": "TestJsonReader", "func_name": 
"partitions"},
                         ),
                         (
-                            "TestJsonReader.read: None",
+                            "TestJsonReader.read: InputPartition(value=None)",
                             {"class_name": "TestJsonReader", "func_name": 
"read"},
                         ),
                     ]
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index cded42031b0c..d736df6084c6 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -210,12 +210,12 @@ def write_read_func_and_partitions(
         # Deserialize the partition value.
         partition = pickleSer.loads(partition_bytes)
 
-        assert partition is None or isinstance(partition, InputPartition), (
+        assert isinstance(partition, InputPartition), (
             "Expected the partition value to be of type 'InputPartition', "
             f"but found '{type(partition).__name__}'."
         )
 
-        output_iter = reader.read(partition)  # type: ignore[arg-type]
+        output_iter = reader.read(partition)
 
         # Validate the output iterator.
         if not isinstance(output_iter, Iterator):
@@ -240,29 +240,26 @@ def write_read_func_and_partitions(
 
     if not is_streaming:
         # The partitioning of python batch source read is determined before 
query execution.
-        try:
-            partitions = reader.partitions()  # type: ignore[call-arg]
-            if not isinstance(partitions, list):
-                raise PySparkRuntimeError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "'partitions' to return a list",
-                        "actual": f"'{type(partitions).__name__}'",
-                    },
-                )
-            if not all(isinstance(p, InputPartition) for p in partitions):
-                partition_types = ", ".join([f"'{type(p).__name__}'" for p in 
partitions])
-                raise PySparkRuntimeError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "elements in 'partitions' to be of type 
'InputPartition'",
-                        "actual": partition_types,
-                    },
-                )
-            if len(partitions) == 0:
-                partitions = [None]  # type: ignore[list-item]
-        except NotImplementedError:
-            partitions = [None]  # type: ignore[list-item]
+        partitions = reader.partitions()  # type: ignore[call-arg]
+        if not isinstance(partitions, list):
+            raise PySparkRuntimeError(
+                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                messageParameters={
+                    "expected": "'partitions' to return a list",
+                    "actual": f"'{type(partitions).__name__}'",
+                },
+            )
+        if not all(isinstance(p, InputPartition) for p in partitions):
+            partition_types = ", ".join([f"'{type(p).__name__}'" for p in 
partitions])
+            raise PySparkRuntimeError(
+                errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                messageParameters={
+                    "expected": "elements in 'partitions' to be of type 
'InputPartition'",
+                    "actual": partition_types,
+                },
+            )
+        if len(partitions) == 0:
+            partitions = [InputPartition(None)]
 
         # Return the serialized partition values.
         write_int(len(partitions), outfile)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bac9849381a3..a712607f18e4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -533,7 +533,7 @@ class PythonDataSourceSuite extends 
PythonDataSourceSuiteBase {
          |        return []
          |
          |    def read(self, partition):
-         |        if partition is None:
+         |        if partition.value is None:
          |            yield ("success", )
          |        else:
          |            yield ("failed", )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to