gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b0b02b231aef [SPARK-48653][PYTHON] Fix invalid Python data source error class references
b0b02b231aef is described below
commit b0b02b231aefc2b11518ba2dacecb361429dafcb
Author: allisonwang-db <[email protected]>
AuthorDate: Fri Jun 21 09:07:40 2024 +0900
[SPARK-48653][PYTHON] Fix invalid Python data source error class references
### What changes were proposed in this pull request?
This PR fixes two invalid error class references and adds more tests. The invalid
classes and their valid replacements (a before/after sketch follows the list):
- `PYTHON_DATA_SOURCE_TYPE_MISMATCH` -> `DATA_SOURCE_TYPE_MISMATCH`
- `PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH` -> `DATA_SOURCE_RETURN_SCHEMA_MISMATCH`
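For illustration, a minimal before/after sketch of the rename, grounded in the hunks below (`reader` here is a stand-in object, not a name from the patch):

```python
from pyspark.errors import PySparkAssertionError

reader = object()  # stand-in for whatever object failed validation

# Before: error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH" is not defined in
# error-classes.json, so resolving the error class itself would fail.
# After: the reference points at an error class that actually exists.
raise PySparkAssertionError(
    error_class="DATA_SOURCE_TYPE_MISMATCH",
    message_parameters={
        "expected": "an instance of DataSourceReader",
        "actual": f"'{type(reader).__name__}'",
    },
)
```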
### Why are the changes needed?
To fix invalid error class references.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests.
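For example, the added tests assert that the user-facing message now carries the valid error class name in brackets. A sketch mirroring the new test in test_python_datasource.py below (`self` is the test case and `"test"` the data source registered in that test):

```python
from pyspark.errors import AnalysisException

# Reading from a source whose reader() returns the wrong type should now
# surface the corrected error class in the message.
with self.assertRaisesRegex(
    AnalysisException,
    r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of DataSourceReader",
):
    self.spark.read.format("test").load().show()
```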
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #47013 from allisonwang-db/spark-48653-fix-pyds-err-cls.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../streaming/python_streaming_source_runner.py | 4 +-
python/pyspark/sql/tests/test_python_datasource.py | 46 +++++++++++++++++++++-
.../pyspark/sql/worker/commit_data_source_write.py | 11 +-----
python/pyspark/sql/worker/create_data_source.py | 8 ++--
python/pyspark/sql/worker/plan_data_source_read.py | 41 +++++++++++++------
.../sql/worker/python_streaming_sink_runner.py | 6 +--
.../pyspark/sql/worker/write_into_data_source.py | 15 ++++++-
7 files changed, 97 insertions(+), 34 deletions(-)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 5292e2f92784..754ecff61b97 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -130,7 +130,7 @@ def main(infile: IO, outfile: IO) -> None:
if not isinstance(data_source, DataSource):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a Python data source instance of type
'DataSource'",
"actual": f"'{type(data_source).__name__}'",
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
schema = _parse_datatype_json_string(schema_json)
if not isinstance(schema, StructType):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an output schema of type 'StructType'",
"actual": f"'{type(schema).__name__}'",
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index d028a210b007..8431e9b3e35d 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -19,7 +19,7 @@ import tempfile
import unittest
from typing import Callable, Union
-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, AnalysisException
from pyspark.sql.datasource import (
DataSource,
DataSourceReader,
@@ -154,7 +154,8 @@ class BasePythonDataSourceTestsMixin:
read_func=lambda schema, partition: iter([Row(i=1, j=2), Row(j=3, k=4)])
)
with self.assertRaisesRegex(
- PythonException, "PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH"
+ PythonException,
+ r"\[DATA_SOURCE_RETURN_SCHEMA_MISMATCH\] Return schema mismatch in
the result",
):
self.spark.read.format("test").load().show()
@@ -373,6 +374,47 @@ class BasePythonDataSourceTestsMixin:
self.assertEqual(d2["BaR"], 3)
self.assertEqual(d2["baz"], 3)
+ def test_data_source_type_mismatch(self):
+ class TestDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return "test"
+
+ def schema(self):
+ return "id int"
+
+ def reader(self, schema):
+ return TestReader()
+
+ def writer(self, schema, overwrite):
+ return TestWriter()
+
+ class TestReader:
+ def partitions(self):
+ return []
+
+ def read(self, partition):
+ yield (0,)
+
+ class TestWriter:
+ def write(self, iterator):
+ return WriterCommitMessage()
+
+ self.spark.dataSource.register(TestDataSource)
+
+ with self.assertRaisesRegex(
+ AnalysisException,
+ r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of
DataSourceReader",
+ ):
+ self.spark.read.format("test").load().show()
+
+ df = self.spark.range(10)
+ with self.assertRaisesRegex(
+ AnalysisException,
+ r"\[DATA_SOURCE_TYPE_MISMATCH\] Expected an instance of
DataSourceWriter",
+ ):
+ df.write.format("test").mode("append").saveAsTable("test_table")
+
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index c7783df449d8..1d9e53083d4d 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -60,14 +60,7 @@ def main(infile: IO, outfile: IO) -> None:
# Receive the data source writer instance.
writer = pickleSer._read_with_length(infile)
- if not isinstance(writer, DataSourceWriter):
- raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
- message_parameters={
- "expected": "an instance of DataSourceWriter",
- "actual": f"'{type(writer).__name__}'",
- },
- )
+ assert isinstance(writer, DataSourceWriter)
# Receive the commit messages.
num_messages = read_int(infile)
@@ -76,7 +69,7 @@ def main(infile: IO, outfile: IO) -> None:
message = pickleSer._read_with_length(infile)
if message is not None and not isinstance(message, WriterCommitMessage):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an instance of WriterCommitMessage",
"actual": f"'{type(message).__name__}'",
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 33394cdff876..d6b59b04393d 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -75,7 +75,7 @@ def main(infile: IO, outfile: IO) -> None:
data_source_cls = read_command(pickleSer, infile)
if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a subclass of DataSource",
"actual": f"'{type(data_source_cls).__name__}'",
@@ -85,7 +85,7 @@ def main(infile: IO, outfile: IO) -> None:
# Check the name method is a class method.
if not inspect.ismethod(data_source_cls.name):
raise PySparkTypeError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "'name()' method to be a classmethod",
"actual": f"'{type(data_source_cls.name).__name__}'",
@@ -98,7 +98,7 @@ def main(infile: IO, outfile: IO) -> None:
# Check if the provider name matches the data source's name.
if provider.lower() != data_source_cls.name().lower():
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": f"provider with name {data_source_cls.name()}",
"actual": f"'{provider}'",
@@ -111,7 +111,7 @@ def main(infile: IO, outfile: IO) -> None:
user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not isinstance(user_specified_schema, StructType):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "the user-defined schema to be a
'StructType'",
"actual": f"'{type(data_source_cls).__name__}'",
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index be7ebd20f180..51a90bba1454 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -20,7 +20,7 @@ import sys
import functools
import pyarrow as pa
from itertools import islice
-from typing import IO, List, Iterator, Iterable, Tuple
+from typing import IO, List, Iterator, Iterable, Tuple, Union
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
@@ -32,7 +32,12 @@ from pyspark.serializers import (
)
from pyspark.sql import Row
from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
-from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceReader,
+ DataSourceStreamReader,
+ InputPartition,
+)
from pyspark.sql.datasource_internal import _streamReader
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.sql.types import (
@@ -108,7 +113,7 @@ def records_to_arrow_batches(
# Check if the names are the same as the schema.
if set(result.__fields__) != col_name_set:
raise PySparkRuntimeError(
- error_class="PYTHON_DATA_SOURCE_READ_RETURN_SCHEMA_MISMATCH",
+ error_class="DATA_SOURCE_RETURN_SCHEMA_MISMATCH",
message_parameters={
"expected": str(column_names),
"actual": str(result.__fields__),
@@ -187,7 +192,7 @@ def main(infile: IO, outfile: IO) -> None:
schema = _parse_datatype_json_string(schema_json)
if not isinstance(schema, StructType):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an output schema of type 'StructType'",
"actual": f"'{type(schema).__name__}'",
@@ -204,11 +209,21 @@ def main(infile: IO, outfile: IO) -> None:
is_streaming = read_bool(infile)
# Instantiate data source reader.
- reader = (
- _streamReader(data_source, schema)
- if is_streaming
- else data_source.reader(schema=schema)
- )
+ if is_streaming:
+ reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(
+ data_source, schema
+ )
+ else:
+ reader = data_source.reader(schema=schema)
+ # Validate the reader.
+ if not isinstance(reader, DataSourceReader):
+ raise PySparkAssertionError(
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "an instance of DataSourceReader",
+ "actual": f"'{type(reader).__name__}'",
+ },
+ )
# Create input converter.
converter = ArrowTableToRowsConversion._create_converter(BinaryType())
@@ -241,7 +256,7 @@ def main(infile: IO, outfile: IO) -> None:
f"but found '{type(partition).__name__}'."
)
- output_iter = reader.read(partition) # type: ignore[attr-defined]
+ output_iter = reader.read(partition) # type: ignore[arg-type]
# Validate the output iterator.
if not isinstance(output_iter, Iterator):
@@ -264,7 +279,7 @@ def main(infile: IO, outfile: IO) -> None:
if not is_streaming:
# The partitioning of python batch source read is determined before query execution.
try:
- partitions = reader.partitions() # type: ignore[attr-defined]
+ partitions = reader.partitions() # type: ignore[call-arg]
if not isinstance(partitions, list):
raise PySparkRuntimeError(
error_class="DATA_SOURCE_TYPE_MISMATCH",
@@ -283,9 +298,9 @@ def main(infile: IO, outfile: IO) -> None:
},
)
if len(partitions) == 0:
- partitions = [None]
+ partitions = [None] # type: ignore[list-item]
except NotImplementedError:
- partitions = [None]
+ partitions = [None] # type: ignore[list-item]
# Return the serialized partition values.
write_int(len(partitions), outfile)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index b84234b309f9..7d03157d705d 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -70,7 +70,7 @@ def main(infile: IO, outfile: IO) -> None:
if not isinstance(data_source, DataSource):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "a Python data source instance of type
'DataSource'",
"actual": f"'{type(data_source).__name__}'",
@@ -81,7 +81,7 @@ def main(infile: IO, outfile: IO) -> None:
schema = _parse_datatype_json_string(schema_json)
if not isinstance(schema, StructType):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an output schema of type 'StructType'",
"actual": f"'{type(schema).__name__}'",
@@ -101,7 +101,7 @@ def main(infile: IO, outfile: IO) -> None:
message = pickleSer._read_with_length(infile)
if message is not None and not isinstance(message, WriterCommitMessage):
raise PySparkAssertionError(
- error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "an instance of WriterCommitMessage",
"actual": f"'{type(message).__name__}'",
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 5714f35cbe71..212a2754ec9f 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -29,7 +29,12 @@ from pyspark.serializers import (
SpecialLengths,
)
from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceWriter,
+ WriterCommitMessage,
+ CaseInsensitiveDict,
+)
from pyspark.sql.types import (
_parse_datatype_json_string,
StructType,
@@ -162,6 +167,14 @@ def main(infile: IO, outfile: IO) -> None:
else:
# Instantiate the data source writer.
writer = data_source.writer(schema, overwrite)  # type: ignore[assignment]
+ if not isinstance(writer, DataSourceWriter):
+ raise PySparkAssertionError(
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "an instance of DataSourceWriter",
+ "actual": f"'{type(writer).__name__}'",
+ },
+ )
# Create a function that can be used in mapInArrow.
import pyarrow as pa