This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8e1e126c68aa [SPARK-55121][PYTHON][SS] Add DataStreamReader.name() to Classic PySpark
8e1e126c68aa is described below

commit 8e1e126c68aaa94f0fa6861b93042e1385615418
Author: ericm-db <[email protected]>
AuthorDate: Fri Jan 23 10:08:31 2026 -0800

    [SPARK-55121][PYTHON][SS] Add DataStreamReader.name() to Classic PySpark
    
    ### What changes were proposed in this pull request?
    
    This PR adds the `name()` method to Classic PySpark's `DataStreamReader` class. This method allows users to specify a name for streaming sources, which is used in checkpoint metadata and enables stable checkpoint locations for source evolution.
    
    Changes include:
    - Add `name()` method to `DataStreamReader` in `python/pyspark/sql/streaming/readwriter.py`
    - Add comprehensive test suite in `python/pyspark/sql/tests/streaming/test_streaming_reader_name.py`
    - Update compatibility test to mark `name` as currently missing from Connect (until the Connect PR merges)
    
    The method validates that `source_name` contains only ASCII letters, digits, and underscores, raising `PySparkTypeError` for non-string inputs and `PySparkValueError` for names with disallowed characters.
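    
    For illustration, a minimal sketch of this validation behavior (assuming an active `spark` session; the names and printed messages are only illustrative):
    
    ```python
    from pyspark.errors import PySparkTypeError, PySparkValueError
    
    # Valid: only ASCII letters, digits, and underscores
    spark.readStream.format("rate").name("my_source_1")
    
    # A non-string argument raises PySparkTypeError (error class NOT_STR)
    try:
        spark.readStream.format("rate").name(123)
    except PySparkTypeError as e:
        print(e)
    
    # Disallowed characters raise PySparkValueError (error class INVALID_STREAMING_SOURCE_NAME)
    try:
        spark.readStream.format("rate").name("my-source")
    except PySparkValueError as e:
        print(e)
    ```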
    
    ### Why are the changes needed?
    
    This brings Classic PySpark to feature parity with the Scala/Java API for streaming source naming. The `name()` method is essential for:
    1. Identifying sources in checkpoint metadata
    2. Enabling stable checkpoint locations during source evolution
    3. Providing consistency across Classic and Connect implementations
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now call `.name()` on `DataStreamReader` in Classic PySpark:
    ```python
    spark.readStream.format("parquet").name("my_source").load("/path")
    ```
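    
    As a usage sketch of the same method in a full pipeline (the schema, paths, source name, and noop sink below are illustrative, mirroring the new tests):
    
    ```python
    df = (
        spark.readStream.format("parquet")
        .schema("id LONG")
        .name("orders_source")  # illustrative source name
        .load("/tmp/input")     # illustrative input path
    )
    query = (
        df.writeStream.format("noop")
        .option("checkpointLocation", "/tmp/checkpoint")  # illustrative checkpoint path
        .start()
    )
    query.stop()
    ```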
    
    ### How was this patch tested?
    
    - Added comprehensive unit tests in `test_streaming_reader_name.py` covering:
      - Valid name patterns (letters, digits, underscores)
      - Invalid names (hyphens, spaces, dots, special characters, empty strings, None, wrong types)
      - Method chaining
      - Different data formats (parquet, json)
      - Integration with streaming queries
    - Updated compatibility tests to account for the current state where Classic has `name` but Connect does not yet have it
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53898 from ericm-db/classic-datastream-reader-name.
    
    Authored-by: ericm-db <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 dev/sparktestsupport/modules.py                    |   1 +
 python/pyspark/errors/error-conditions.json        |   5 +
 python/pyspark/sql/streaming/readwriter.py         |  47 ++++++
 .../tests/streaming/test_streaming_reader_name.py  | 179 +++++++++++++++++++++
 .../sql/tests/test_connect_compatibility.py        |   2 +-
 5 files changed, 233 insertions(+), 1 deletion(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 5cf0f745a151..6ba2e619f703 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -668,6 +668,7 @@ pyspark_structured_streaming = Module(
         "pyspark.sql.tests.streaming.test_streaming_foreach",
         "pyspark.sql.tests.streaming.test_streaming_foreach_batch",
         "pyspark.sql.tests.streaming.test_streaming_listener",
+        "pyspark.sql.tests.streaming.test_streaming_reader_name",
         "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state",
         "pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state",
         "pyspark.sql.tests.pandas.streaming.test_pandas_transform_with_state_checkpoint_v2",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 7943c0992603..ee35e237b898 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -465,6 +465,11 @@
       "Parameter value <arg_name> must be a valid UUID format: <origin>"
     ]
   },
+  "INVALID_STREAMING_SOURCE_NAME": {
+    "message": [
+      "Invalid streaming source name '<source_name>'. Source names must contain only ASCII letters, digits, and underscores."
+    ]
+  },
   "INVALID_TIMEOUT_TIMESTAMP": {
     "message": [
       "Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py
index ffb4415eaac3..50c28f49ea1f 100644
--- a/python/pyspark/sql/streaming/readwriter.py
+++ b/python/pyspark/sql/streaming/readwriter.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import re
 import sys
 from collections.abc import Iterator
 from typing import cast, overload, Any, Callable, List, Optional, TYPE_CHECKING, Union
@@ -241,6 +242,52 @@ class DataStreamReader(OptionUtils):
             self._jreader = self._jreader.option(k, to_str(options[k]))
         return self
 
+    def name(self, source_name: str) -> "DataStreamReader":
+        """Specifies a name for the streaming source.
+
+        This name is used to identify the source in checkpoint metadata and enables
+        stable checkpoint locations for source evolution.
+
+        .. versionadded:: 4.2.0
+
+        Parameters
+        ----------
+        source_name : str
+            the name to assign to this streaming source. Must contain only ASCII letters,
+            digits, and underscores.
+
+        Returns
+        -------
+        :class:`DataStreamReader`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> spark.readStream.format("rate").name("my_source")  # doctest: +SKIP
+        <...streaming.readwriter.DataStreamReader object ...>
+        """
+        if not isinstance(source_name, str):
+            raise PySparkTypeError(
+                errorClass="NOT_STR",
+                messageParameters={
+                    "arg_name": "source_name",
+                    "arg_type": type(source_name).__name__,
+                },
+            )
+
+        # Validate that source_name contains only ASCII letters, digits, and underscores
+        if not re.match(r"^[a-zA-Z0-9_]+$", source_name):
+            raise PySparkValueError(
+                errorClass="INVALID_STREAMING_SOURCE_NAME",
+                messageParameters={"source_name": source_name},
+            )
+
+        self._jreader = self._jreader.name(source_name)
+        return self
+
     def load(
         self,
         path: Optional[str] = None,
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py b/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py
new file mode 100644
index 000000000000..8b3af847df7e
--- /dev/null
+++ b/python/pyspark/sql/tests/streaming/test_streaming_reader_name.py
@@ -0,0 +1,179 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import time
+
+from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataStreamReaderNameTests(ReusedSQLTestCase):
+    """Test suite for DataStreamReader.name() functionality in PySpark."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(DataStreamReaderNameTests, cls).setUpClass()
+        # Enable streaming source evolution feature
+        cls.spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", "true")
+        cls.spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")
+
+    def test_name_with_valid_names(self):
+        """Test that various valid source name patterns work correctly."""
+        valid_names = [
+            "mySource",
+            "my_source",
+            "MySource123",
+            "_private",
+            "source_123_test",
+            "123source",
+        ]
+
+        for name in valid_names:
+            with tempfile.TemporaryDirectory(prefix=f"test_{name}_") as tmpdir:
+                self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                df = (
+                    self.spark.readStream.format("parquet")
+                    .schema("id LONG")
+                    .name(name)
+                    .load(tmpdir)
+                )
+                self.assertTrue(df.isStreaming, f"DataFrame should be streaming for name: {name}")
+
+    def test_name_method_chaining(self):
+        """Test that name() returns the reader for method chaining."""
+        with tempfile.TemporaryDirectory(prefix="test_chaining_") as tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.format("parquet")
+                .schema("id LONG")
+                .name("my_source")
+                .option("maxFilesPerTrigger", "1")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_name_before_format(self):
+        """Test that order doesn't matter - name can be set before format."""
+        with tempfile.TemporaryDirectory(prefix="test_before_format_") as tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.name("my_source")
+                .format("parquet")
+                .schema("id LONG")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_invalid_names(self):
+        """Test that various invalid source names are rejected."""
+        invalid_names = [
+            "",  # empty string
+            "  ",  # whitespace only
+            "my-source",  # hyphen
+            "my source",  # space
+            "my.source",  # dot
+            "my@source",  # special char
+            "my$source",  # dollar sign
+            "my#source",  # hash
+            "my!source",  # exclamation
+        ]
+
+        for invalid_name in invalid_names:
+            with self.subTest(name=invalid_name):
+                with tempfile.TemporaryDirectory(prefix="test_invalid_") as tmpdir:
+                    self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                    with self.assertRaises(PySparkValueError) as context:
+                        self.spark.readStream.format("parquet").schema("id LONG").name(
+                            invalid_name
+                        ).load(tmpdir)
+
+                    # The error message should contain information about invalid name
+                    self.assertIn("source", str(context.exception).lower())
+
+    def test_invalid_name_wrong_type(self):
+        """Test that None and non-string types are rejected."""
+        invalid_types = [None, 123, 45.67, [], {}]
+
+        for invalid_value in invalid_types:
+            with self.subTest(value=invalid_value):
+                with self.assertRaises(PySparkTypeError):
+                    self.spark.readStream.format("rate").name(invalid_value).load()
+
+    def test_name_with_different_formats(self):
+        """Test that name() works with different streaming data sources."""
+        with tempfile.TemporaryDirectory(prefix="test_name_formats_") as tmpdir:
+            # Create test data
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir + "/parquet_data")
+            self.spark.range(10).selectExpr("id", "CAST(id AS STRING) as value").write.mode(
+                "overwrite"
+            ).json(tmpdir + "/json_data")
+
+            # Test with parquet
+            parquet_df = (
+                self.spark.readStream.format("parquet")
+                .name("parquet_source")
+                .schema("id LONG")
+                .load(tmpdir + "/parquet_data")
+            )
+            self.assertTrue(parquet_df.isStreaming, "Parquet DataFrame should be streaming")
+
+            # Test with json - specify schema
+            json_df = (
+                self.spark.readStream.format("json")
+                .name("json_source")
+                .schema("id LONG, value STRING")
+                .load(tmpdir + "/json_data")
+            )
+            self.assertTrue(json_df.isStreaming, "JSON DataFrame should be streaming")
+
+    def test_name_persists_through_query(self):
+        """Test that the name persists when starting a streaming query."""
+        with tempfile.TemporaryDirectory(prefix="test_name_query_") as tmpdir:
+            data_dir = tmpdir + "/data"
+            checkpoint_dir = tmpdir + "/checkpoint"
+
+            # Create test data
+            self.spark.range(10).write.mode("overwrite").parquet(data_dir)
+
+            df = (
+                self.spark.readStream.format("parquet")
+                .schema("id LONG")
+                .name("parquet_source_test")
+                .load(data_dir)
+            )
+
+            query = (
+                df.writeStream.format("noop").option("checkpointLocation", checkpoint_dir).start()
+            )
+
+            try:
+                # Let it run briefly
+                time.sleep(1)
+
+                # Verify query is running
+                self.assertTrue(query.isActive, "Query should be active")
+            finally:
+                query.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.testing import main
+
+    main()
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index 56b212387fe4..e0645586deac 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -487,7 +487,7 @@ class ConnectCompatibilityTestsMixin:
         """Test Data Stream Reader compatibility between classic and connect."""
         expected_missing_connect_properties = set()
         expected_missing_classic_properties = set()
-        expected_missing_connect_methods = set()
+        expected_missing_connect_methods = {"name"}
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicDataStreamReader,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
