This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b0b02b231aef [SPARK-48653][PYTHON] Fix invalid Python data source error class references
b0b02b231aef is described below

commit b0b02b231aefc2b11518ba2dacecb361429dafcb
Author: allisonwang-db <[email protected]>
AuthorDate: Fri Jun 21 09:07:40 2024 +0900

    [SPARK-48653][PYTHON] Fix invalid Python data source error class references
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a few invalid error class references and adds more tests. Two error classes are invalid (a usage sketch follows this list):
    - `PYTHON_DATA_SOURCE_TYPE_MISMATCH` -> `DATA_SOURCE_TYPE_MISMATCH`
    - `PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH` -> `DATA_SOURCE_RETURN_SCHEMA_MISMATCH`
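
    For illustration, a minimal sketch of the corrected usage, mirroring the
    validation added in plan_data_source_read.py below (the `reader` value
    here is a hypothetical stand-in):

    ```python
    from pyspark.errors import PySparkAssertionError
    from pyspark.sql.datasource import DataSourceReader

    reader = object()  # stand-in for whatever data_source.reader() returned

    # Reference the registered error class name; the old
    # "PYTHON_DATA_SOURCE_TYPE_MISMATCH" name is not defined in the error
    # framework, so raising it would itself fail.
    if not isinstance(reader, DataSourceReader):
        raise PySparkAssertionError(
            error_class="DATA_SOURCE_TYPE_MISMATCH",
            message_parameters={
                "expected": "an instance of DataSourceReader",
                "actual": f"'{type(reader).__name__}'",
            },
        )
    ```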
    
    ### Why are the changes needed?
    
    To fix invalid error class references.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests.
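
    For example, one of the new assertions (quoted from
    test_python_datasource.py in this diff) matches on the corrected error
    class name in the raised message:

    ```python
    # Quoted from the new test_data_source_type_mismatch test below; runs
    # inside a test case that has registered the "test" data source.
    with self.assertRaisesRegex(
        AnalysisException,
        r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceReader",
    ):
        self.spark.read.format("test").load().show()
    ```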
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #47013 from allisonwang-db/spark-48653-fix-pyds-err-cls.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../streaming/python_streaming_source_runner.py    |  4 +-
 python/pyspark/sql/tests/test_python_datasource.py | 46 +++++++++++++++++++++-
 .../pyspark/sql/worker/commit_data_source_write.py | 11 +-----
 python/pyspark/sql/worker/create_data_source.py    |  8 ++--
 python/pyspark/sql/worker/plan_data_source_read.py | 41 +++++++++++++------
 .../sql/worker/python_streaming_sink_runner.py     |  6 +--
 .../pyspark/sql/worker/write_into_data_source.py   | 15 ++++++-
 7 files changed, 97 insertions(+), 34 deletions(-)

diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 5292e2f92784..754ecff61b97 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -130,7 +130,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         if not isinstance(data_source, DataSource):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "a Python data source instance of type 
'DataSource'",
                     "actual": f"'{type(data_source).__name__}'",
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
         schema = _parse_datatype_json_string(schema_json)
         if not isinstance(schema, StructType):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "an output schema of type 'StructType'",
                     "actual": f"'{type(schema).__name__}'",
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index d028a210b007..8431e9b3e35d 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -19,7 +19,7 @@ import tempfile
 import unittest
 from typing import Callable, Union
 
-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, AnalysisException
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceReader,
@@ -154,7 +154,8 @@ class BasePythonDataSourceTestsMixin:
             read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
         )
         with self.assertRaisesRegex(
-            PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+            PythonException,
+            r"\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\] Return schema mismatch in 
the result",
         ):
             self.spark.read.format("test").load().show()
 
@@ -373,6 +374,47 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(d2["BaR"], 3)
         self.assertEqual(d2["baz"], 3)
 
+    def test_data_source_type_mismatch(self):
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "id int"
+
+            def reader(self, schema):
+                return TestReader()
+
+            def writer(self, schema, overwrite):
+                return TestWriter()
+
+        class TestReader:
+            def partitions(self):
+                return []
+
+            def read(self, partition):
+                yield (0,)
+
+        class TestWriter:
+            def write(self, iterator):
+                return WriterCommitMessage()
+
+        self.spark.dataSource.register(TestDataSource)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of 
DataSourceReader",
+        ):
+            self.spark.read.format("test").load().show()
+
+        df = self.spark.range(10)
+        with self.assertRaisesRegex(
+            AnalysisException,
+            r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of 
DataSourceWriter",
+        ):
+            df.write.format("test").mode("append").saveAsTable("test_table")
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index c7783df449d8..1d9e53083d4d 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -60,14 +60,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Receive the data source writer instance.
         writer = pickleSer._read_with_length(infile)
-        if not isinstance(writer, DataSourceWriter):
-            raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
-                message_parameters={
-                    "expected": "an instance of DataSourceWriter",
-                    "actual": f"'{type(writer).__name__}'",
-                },
-            )
+        assert isinstance(writer, DataSourceWriter)
 
         # Receive the commit messages.
         num_messages = read_int(infile)
@@ -76,7 +69,7 @@ def main(infile: IO, outfile: IO) -> None:
             message = pickleSer._read_with_length(infile)
             if message is not None and not isinstance(message, WriterCommitMessage):
                 raise PySparkAssertionError(
-                    error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
                     message_parameters={
                         "expected": "an instance of WriterCommitMessage",
                         "actual": f"'{type(message).__name__}'",
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 33394cdff876..d6b59b04393d 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -75,7 +75,7 @@ def main(infile: IO, outfile: IO) -> None:
         data_source_cls = read_command(pickleSer, infile)
         if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "a subclass of DataSource",
                     "actual": f"'{type(data_source_cls).__name__}'",
@@ -85,7 +85,7 @@ def main(infile: IO, outfile: IO) -> None:
         # Check the name method is a class method.
         if not inspect.ismethod(data_source_cls.name):
             raise PySparkTypeError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "'name()' method to be a classmethod",
                     "actual": f"'{type(data_source_cls.name).__name__}'",
@@ -98,7 +98,7 @@ def main(infile: IO, outfile: IO) -> None:
         # Check if the provider name matches the data source's name.
         if provider.lower() != data_source_cls.name().lower():
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": f"provider with name {data_source_cls.name()}",
                     "actual": f"'{provider}'",
@@ -111,7 +111,7 @@ def main(infile: IO, outfile: IO) -> None:
             user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
             if not isinstance(user_specified_schema, StructType):
                 raise PySparkAssertionError(
-                    error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
                     message_parameters={
                         "expected": "the user-defined schema to be a 
'StructType'",
                         "actual": f"'{type(data_source_cls).__name__}'",
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index be7ebd20f180..51a90bba1454 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -20,7 +20,7 @@ import sys
 import functools
 import pyarrow as pa
 from itertools import islice
-from typing import IO, List, Iterator, Iterable, Tuple
+from typing import IO, List, Iterator, Iterable, Tuple, Union
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -32,7 +32,12 @@ from pyspark.serializers import (
 )
 from pyspark.sql import Row
 from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
-from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceReader,
+    DataSourceStreamReader,
+    InputPartition,
+)
 from pyspark.sql.datasource_internal import _streamReader
 from pyspark.sql.pandas.types import to_arrow_schema
 from pyspark.sql.types import (
@@ -108,7 +113,7 @@ def records_to_arrow_batches(
                 # Check if the names are the same as the schema.
                 if set(result.__fields__) != col_name_set:
                     raise PySparkRuntimeError(
-                        error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+                        error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
                         message_parameters={
                             "expected": str(column_names),
                             "actual": str(result.__fields__),
@@ -187,7 +192,7 @@ def main(infile: IO, outfile: IO) -> None:
         schema = _parse_datatype_json_string(schema_json)
         if not isinstance(schema, StructType):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "an output schema of type 'StructType'",
                     "actual": f"'{type(schema).__name__}'",
@@ -204,11 +209,21 @@ def main(infile: IO, outfile: IO) -> None:
         is_streaming = read_bool(infile)
 
         # Instantiate data source reader.
-        reader = (
-            _streamReader(data_source, schema)
-            if is_streaming
-            else data_source.reader(schema=schema)
-        )
+        if is_streaming:
+            reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(
+                data_source, schema
+            )
+        else:
+            reader = data_source.reader(schema=schema)
+            # Validate the reader.
+            if not isinstance(reader, DataSourceReader):
+                raise PySparkAssertionError(
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
+                    message_parameters={
+                        "expected": "an instance of DataSourceReader",
+                        "actual": f"'{type(reader).__name__}'",
+                    },
+                )
 
         # Create input converter.
         converter = ArrowTableToRowsConversion._create_converter(BinaryType())
@@ -241,7 +256,7 @@ def main(infile: IO, outfile: IO) -> None:
                 f"but found '{type(partition).__name__}'."
             )
 
-            output_iter = reader.read(partition)  # type: ignore[attr-defined]
+            output_iter = reader.read(partition)  # type: ignore[arg-type]
 
             # Validate the output iterator.
             if not isinstance(output_iter, Iterator):
@@ -264,7 +279,7 @@ def main(infile: IO, outfile: IO) -> None:
         if not is_streaming:
             # The partitioning of python batch source read is determined before query execution.
             try:
-                partitions = reader.partitions()  # type: ignore[attr-defined]
+                partitions = reader.partitions()  # type: ignore[call-arg]
                 if not isinstance(partitions, list):
                     raise PySparkRuntimeError(
                         error_class="DATA_SOURCE_TYPE_MISMATCH",
@@ -283,9 +298,9 @@ def main(infile: IO, outfile: IO) -> None:
                         },
                     )
                 if len(partitions) == 0:
-                    partitions = [None]
+                    partitions = [None]  # type: ignore[list-item]
             except NotImplementedError:
-                partitions = [None]
+                partitions = [None]  # type: ignore[list-item]
 
             # Return the serialized partition values.
             write_int(len(partitions), outfile)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index b84234b309f9..7d03157d705d 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -70,7 +70,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         if not isinstance(data_source, DataSource):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "a Python data source instance of type 
'DataSource'",
                     "actual": f"'{type(data_source).__name__}'",
@@ -81,7 +81,7 @@ def main(infile: IO, outfile: IO) -> None:
         schema = _parse_datatype_json_string(schema_json)
         if not isinstance(schema, StructType):
             raise PySparkAssertionError(
-                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                error_class="DATA_SOURCE_TYPE_MISMATCH",
                 message_parameters={
                     "expected": "an output schema of type 'StructType'",
                     "actual": f"'{type(schema).__name__}'",
@@ -101,7 +101,7 @@ def main(infile: IO, outfile: IO) -> None:
                 message = pickleSer._read_with_length(infile)
                if message is not None and not isinstance(message, WriterCommitMessage):
                     raise PySparkAssertionError(
-                        error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                        error_class="DATA_SOURCE_TYPE_MISMATCH",
                         message_parameters={
                             "expected": "an instance of WriterCommitMessage",
                             "actual": f"'{type(message).__name__}'",
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 5714f35cbe71..212a2754ec9f 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -29,7 +29,12 @@ from pyspark.serializers import (
     SpecialLengths,
 )
 from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceWriter,
+    WriterCommitMessage,
+    CaseInsensitiveDict,
+)
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
@@ -162,6 +167,14 @@ def main(infile: IO, outfile: IO) -> None:
         else:
             # Instantiate the data source writer.
             writer = data_source.writer(schema, overwrite)  # type: ignore[assignment]
+            if not isinstance(writer, DataSourceWriter):
+                raise PySparkAssertionError(
+                    error_class="DATA_SOURCE_TYPE_MISMATCH",
+                    message_parameters={
+                        "expected": "an instance of DataSourceWriter",
+                        "actual": f"'{type(writer).__name__}'",
+                    },
+                )
 
         # Create a function that can be used in mapInArrow.
         import pyarrow as pa

