This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 8e1e126c68aa [SPARK-55121][PYTHON][SS] Add DataStreamReader.name() to Classic PySpark
8e1e126c68aa is described below
commit 8e1e126c68aaa94f0fa6861b93042e1385615418
Author: ericm-db <[email protected]>
AuthorDate: Fri Jan 23 10:08:31 2026 -0800
[SPARK-55121][PYTHON][SS] Add DataStreamReader.name() to Classic PySpark
### What changes were proposed in this pull request?
This PR adds the `name()` method to Classic PySpark's `DataStreamReader`
class. This method allows users to specify a name for streaming sources, which
is used in checkpoint metadata and enables stable checkpoint locations for
source evolution.
Changes include:
- Add `name()` method to `DataStreamReader` in
`python/pyspark/sql/streaming/readwriter.py`
- Add comprehensive test suite in
`python/pyspark/sql/tests/streaming/test_streaming_reader_name.py`
- Update compatibility test to mark `name` as currently missing from
Connect (until the Connect PR merges)
The method validates that `source_name` contains only ASCII letters,
digits, and underscores, raising `PySparkTypeError` for non-string
inputs and `PySparkValueError` for names that violate the pattern.
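For illustration, a minimal sketch of the resulting behavior (the `rate`
format and the names here are arbitrary examples; the error classes match
the diff below):
```python
from pyspark.errors import PySparkTypeError, PySparkValueError

spark.readStream.format("rate").name("my_source_1")    # accepted

try:
    spark.readStream.format("rate").name("my-source")  # hyphen: rejected
except PySparkValueError:
    pass  # INVALID_STREAMING_SOURCE_NAME

try:
    spark.readStream.format("rate").name(123)          # non-string: rejected
except PySparkTypeError:
    pass  # NOT_STR
```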
### Why are the changes needed?
This brings Classic PySpark to feature parity with the Scala/Java API for
streaming source naming. The `name()` method is essential for:
1. Identifying sources in checkpoint metadata
2. Enabling stable checkpoint locations during source evolution
3. Providing consistency across Classic and Connect implementations
### Does this PR introduce _any_ user-facing change?
Yes. Users can now call `.name()` on `DataStreamReader` in Classic PySpark:
```python
spark.readStream.format("parquet").name("my_source").load("/path")
```
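For a fuller sketch, mirroring the new `test_name_persists_through_query`
test below (the paths are placeholders, and the two confs are the ones the
test suite enables):
```python
spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", "true")
spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")

df = (
    spark.readStream.format("parquet")
    .schema("id LONG")
    .name("parquet_source")  # recorded in checkpoint metadata
    .load("/placeholder/data")
)
query = (
    df.writeStream.format("noop")
    .option("checkpointLocation", "/placeholder/checkpoint")
    .start()
)
query.stop()
```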
### How was this patch tested?
- Added comprehensive unit tests in `test_streaming_reader_name.py`
(runnable standalone; see the command after this list), covering:
- Valid name patterns (letters, digits, underscores)
- Invalid names (hyphens, spaces, dots, special characters, empty
strings, None, wrong types)
- Method chaining
- Different data formats (parquet, json)
- Integration with streaming queries
- Updated compatibility tests to account for the current state where
Classic has `name` but Connect doesn't yet
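The new suite can also be run on its own, e.g. with the standard dev script:
```bash
python/run-tests --testnames 'pyspark.sql.tests.streaming.test_streaming_reader_name'
```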
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53898 from ericm-db/classic-datastream-reader-name.
Authored-by: ericm-db <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
dev/sparktestsupport/modules.py | 1 +
python/pyspark/errors/error-conditions.json | 5 +
python/pyspark/sql/streaming/readwriter.py | 47 ++++++
.../tests/streaming/test_streaming_reader_name.py | 179 +++++++++++++++++++++
.../sql/tests/test_connect_compatibility.py | 2 +-
5 files changed, 233 insertions(+), 1 deletion(-)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5cf0f745a151..6ba2e619f703 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -668,6 +668,7 @@ pyspark_structured_streaming = Module(
"pyspark.sql.tests.streaming.test_streaming_foreach",
"pyspark.sql.tests.streaming.test_streaming_foreach_batch",
"pyspark.sql.tests.streaming.test_streaming_listener",
+ "pyspark.sql.tests.streaming.test_streaming_reader_name",
"pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state",
"pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state_checkpoint_v2",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 7943c0992603..ee35e237b898 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -465,6 +465,11 @@
"Parameter value <arg_name> must be a valid UUID format: <origin>"
]
},
+ "INVALID_STREAMING_SOURCE_NAME": {
+ "message": [
+      "Invalid streaming source name '<source_name>'. Source names must contain only ASCII letters, digits, and underscores."
+ ]
+ },
"INVALID_TIMEOUT_TIMESTAMP": {
"message": [
"Timeout timestamp (<timestamp>) cannot be earlier than the current
watermark (<watermark>)."
diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index ffb4415eaac3..50c28f49ea1f 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import re
import sys
from collections.abc import Iterator
from typing import cast, overload, Any, Callable, List, Optional, TYPE_CHECKING, Union
@@ -241,6 +242,52 @@ class DataStreamReader(OptionUtils):
self._jreader = self._jreader.option(k, to_str(options[k]))
return self
+ def name(self, source_name: str) -> "DataStreamReader":
+ """Specifies a name for the streaming source.
+
+        This name is used to identify the source in checkpoint metadata and enables
+        stable checkpoint locations for source evolution.
+
+ .. versionadded:: 4.2.0
+
+ Parameters
+ ----------
+ source_name : str
+            the name to assign to this streaming source. Must contain only ASCII letters,
+            digits, and underscores.
+
+ Returns
+ -------
+ :class:`DataStreamReader`
+
+ Notes
+ -----
+ This API is experimental.
+
+ Examples
+ --------
+ >>> spark.readStream.format("rate").name("my_source") # doctest: +SKIP
+ <...streaming.readwriter.DataStreamReader object ...>
+ """
+ if not isinstance(source_name, str):
+ raise PySparkTypeError(
+ errorClass="NOT_STR",
+ messageParameters={
+ "arg_name": "source_name",
+ "arg_type": type(source_name).__name__,
+ },
+ )
+
+        # Validate that source_name contains only ASCII letters, digits, and underscores
+ if not re.match(r"^[a-zA-Z0-9_]+$", source_name):
+ raise PySparkValueError(
+ errorClass="INVALID_STREAMING_SOURCE_NAME",
+ messageParameters={"source_name": source_name},
+ )
+
+ self._jreader = self._jreader.name(source_name)
+ return self
+
def load(
self,
path: Optional[str] = None,
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py b/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py
new file mode 100644
index 000000000000..8b3af847df7e
--- /dev/null
+++ b/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import time
+
+from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataStreamReaderNameTests(ReusedSQLTestCase):
+ """Test suite for DataStreamReader.name() functionality in PySpark."""
+
+ @classmethod
+ def setUpClass(cls):
+ super(DataStreamReaderNameTests, cls).setUpClass()
+ # Enable streaming source evolution feature
+        cls.spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", "true")
+ cls.spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")
+
+ def test_name_with_valid_names(self):
+ """Test that various valid source name patterns work correctly."""
+ valid_names = [
+ "mySource",
+ "my_source",
+ "MySource123",
+ "_private",
+ "source_123_test",
+ "123source",
+ ]
+
+ for name in valid_names:
+ with tempfile.TemporaryDirectory(prefix=f"test_{name}_") as tmpdir:
+ self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+ df = (
+ self.spark.readStream.format("parquet")
+ .schema("id LONG")
+ .name(name)
+ .load(tmpdir)
+ )
+                self.assertTrue(df.isStreaming, f"DataFrame should be streaming for name: {name}")
+
+ def test_name_method_chaining(self):
+ """Test that name() returns the reader for method chaining."""
+ with tempfile.TemporaryDirectory(prefix="test_chaining_") as tmpdir:
+ self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+ df = (
+ self.spark.readStream.format("parquet")
+ .schema("id LONG")
+ .name("my_source")
+ .option("maxFilesPerTrigger", "1")
+ .load(tmpdir)
+ )
+
+ self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+ def test_name_before_format(self):
+ """Test that order doesn't matter - name can be set before format."""
+        with tempfile.TemporaryDirectory(prefix="test_before_format_") as tmpdir:
+ self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+ df = (
+ self.spark.readStream.name("my_source")
+ .format("parquet")
+ .schema("id LONG")
+ .load(tmpdir)
+ )
+
+ self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+ def test_invalid_names(self):
+ """Test that various invalid source names are rejected."""
+ invalid_names = [
+ "", # empty string
+ " ", # whitespace only
+ "my-source", # hyphen
+ "my source", # space
+ "my.source", # dot
+ "my@source", # special char
+ "my$source", # dollar sign
+ "my#source", # hash
+ "my!source", # exclamation
+ ]
+
+ for invalid_name in invalid_names:
+ with self.subTest(name=invalid_name):
+                with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
+                    self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+ with self.assertRaises(PySparkValueError) as context:
+                        self.spark.readStream.format("parquet").schema("id LONG").name(
+ invalid_name
+ ).load(tmpdir)
+
+                    # The error message should contain information about the invalid name
+ self.assertIn("source", str(context.exception).lower())
+
+ def test_invalid_name_wrong_type(self):
+ """Test that None and non-string types are rejected."""
+ invalid_types = [None, 123, 45.67, [], {}]
+
+ for invalid_value in invalid_types:
+ with self.subTest(value=invalid_value):
+ with self.assertRaises(PySparkTypeError):
+                    self.spark.readStream.format("rate").name(invalid_value).load()
+
+ def test_name_with_different_formats(self):
+ """Test that name() works with different streaming data sources."""
+        with tempfile.TemporaryDirectory(prefix="test_name_formats_") as tmpdir:
+ # Create test data
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir + "/parquet_data")
+            self.spark.range(10).selectExpr("id", "CAST(id AS STRING) as value").write.mode(
+ "overwrite"
+ ).json(tmpdir + "/json_data")
+
+ # Test with parquet
+ parquet_df = (
+ self.spark.readStream.format("parquet")
+ .name("parquet_source")
+ .schema("id LONG")
+ .load(tmpdir + "/parquet_data")
+ )
+            self.assertTrue(parquet_df.isStreaming, "Parquet DataFrame should be streaming")
+
+ # Test with json - specify schema
+ json_df = (
+ self.spark.readStream.format("json")
+ .name("json_source")
+ .schema("id LONG, value STRING")
+ .load(tmpdir + "/json_data")
+ )
+            self.assertTrue(json_df.isStreaming, "JSON DataFrame should be streaming")
+
+ def test_name_persists_through_query(self):
+ """Test that the name persists when starting a streaming query."""
+ with tempfile.TemporaryDirectory(prefix="test_name_query_") as tmpdir:
+ data_dir = tmpdir + "/data"
+ checkpoint_dir = tmpdir + "/checkpoint"
+
+ # Create test data
+ self.spark.range(10).write.mode("overwrite").parquet(data_dir)
+
+ df = (
+ self.spark.readStream.format("parquet")
+ .schema("id LONG")
+ .name("parquet_source_test")
+ .load(data_dir)
+ )
+
+ query = (
+                df.writeStream.format("noop").option("checkpointLocation", checkpoint_dir).start()
+ )
+
+ try:
+ # Let it run briefly
+ time.sleep(1)
+
+ # Verify query is running
+ self.assertTrue(query.isActive, "Query should be active")
+ finally:
+ query.stop()
+
+
+if __name__ == "__main__":
+ from pyspark.testing import main
+
+ main()
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index 56b212387fe4..e0645586deac 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -487,7 +487,7 @@ class ConnectCompatibilityTestsMixin:
"""Test Data Stream Reader compatibility between classic and
connect."""
expected_missing_connect_properties = set()
expected_missing_classic_properties = set()
- expected_missing_connect_methods = set()
+ expected_missing_connect_methods = {"name"}
expected_missing_classic_methods = set()
self.check_compatibility(
ClassicDataStreamReader,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]