gaogaotiantian commented on code in PR #53898:
URL: https://github.com/apache/spark/pull/53898#discussion_r2714878087


##########
python/pyspark/sql/tests/streaming/test_streaming_reader_name.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import time
+
+from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataStreamReaderNameTests(ReusedSQLTestCase):
+    """Test suite for DataStreamReader.name() functionality in PySpark."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(DataStreamReaderNameTests, cls).setUpClass()
+        # Enable streaming source evolution feature
+        
cls.spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", 
"true")
+        cls.spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")
+
+    def test_name_with_valid_names(self):
+        """Test that various valid source name patterns work correctly."""
+        valid_names = [
+            "mySource",
+            "my_source",
+            "MySource123",
+            "_private",
+            "source_123_test",
+            "123source",
+        ]
+
+        for name in valid_names:
+            with tempfile.TemporaryDirectory(prefix=f"test_{name}_") as tmpdir:
+                self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                df = (
+                    self.spark.readStream.format("parquet")
+                    .schema("id LONG")
+                    .name(name)
+                    .load(tmpdir)
+                )
+                self.assertTrue(df.isStreaming, f"DataFrame should be 
streaming for name: {name}")
+
+    def test_name_method_chaining(self):
+        """Test that name() returns the reader for method chaining."""
+        with tempfile.TemporaryDirectory(prefix="test_chaining_") as tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.format("parquet")
+                .schema("id LONG")
+                .name("my_source")
+                .option("maxFilesPerTrigger", "1")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_name_before_format(self):
+        """Test that order doesn't matter - name can be set before format."""
+        with tempfile.TemporaryDirectory(prefix="test_before_format_") as 
tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.name("my_source")
+                .format("parquet")
+                .schema("id LONG")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_invalid_names(self):
+        """Test that various invalid source names are rejected."""
+        invalid_names = [
+            "my-source",      # hyphen
+            "my source",      # space
+            "my.source",      # dot
+            "my@source",      # special char
+            "my$source",      # dollar sign
+            "my#source",      # hash
+            "my!source",      # exclamation
+        ]
+
+        for invalid_name in invalid_names:
+            with self.subTest(name=invalid_name):
+                with tempfile.TemporaryDirectory(prefix="test_invalid_") as 
tmpdir:
+                    
self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                    with self.assertRaises(PySparkValueError) as context:
+                        self.spark.readStream.format("parquet").schema("id 
LONG").name(
+                            invalid_name
+                        ).load(tmpdir)
+
+                    # The error message should contain information about 
invalid name
+                    self.assertIn("source", str(context.exception).lower())
+
+    def test_invalid_name_empty_string(self):

Review Comment:
   We can probably move this case into the invalid-names test too?



##########
python/pyspark/sql/tests/streaming/test_streaming_reader_name.py:
##########
@@ -0,0 +1,183 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import time
+
+from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class DataStreamReaderNameTests(ReusedSQLTestCase):
+    """Test suite for DataStreamReader.name() functionality in PySpark."""
+
+    @classmethod
+    def setUpClass(cls):
+        super(DataStreamReaderNameTests, cls).setUpClass()
+        # Enable streaming source evolution feature
+        
cls.spark.conf.set("spark.sql.streaming.queryEvolution.enableSourceEvolution", 
"true")
+        cls.spark.conf.set("spark.sql.streaming.offsetLog.formatVersion", "2")
+
+    def test_name_with_valid_names(self):
+        """Test that various valid source name patterns work correctly."""
+        valid_names = [
+            "mySource",
+            "my_source",
+            "MySource123",
+            "_private",
+            "source_123_test",
+            "123source",
+        ]
+
+        for name in valid_names:
+            with tempfile.TemporaryDirectory(prefix=f"test_{name}_") as tmpdir:
+                self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                df = (
+                    self.spark.readStream.format("parquet")
+                    .schema("id LONG")
+                    .name(name)
+                    .load(tmpdir)
+                )
+                self.assertTrue(df.isStreaming, f"DataFrame should be 
streaming for name: {name}")
+
+    def test_name_method_chaining(self):
+        """Test that name() returns the reader for method chaining."""
+        with tempfile.TemporaryDirectory(prefix="test_chaining_") as tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.format("parquet")
+                .schema("id LONG")
+                .name("my_source")
+                .option("maxFilesPerTrigger", "1")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_name_before_format(self):
+        """Test that order doesn't matter - name can be set before format."""
+        with tempfile.TemporaryDirectory(prefix="test_before_format_") as 
tmpdir:
+            self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+            df = (
+                self.spark.readStream.name("my_source")
+                .format("parquet")
+                .schema("id LONG")
+                .load(tmpdir)
+            )
+
+            self.assertTrue(df.isStreaming, "DataFrame should be streaming")
+
+    def test_invalid_names(self):
+        """Test that various invalid source names are rejected."""
+        invalid_names = [
+            "my-source",      # hyphen
+            "my source",      # space
+            "my.source",      # dot
+            "my@source",      # special char
+            "my$source",      # dollar sign
+            "my#source",      # hash
+            "my!source",      # exclamation
+        ]
+
+        for invalid_name in invalid_names:
+            with self.subTest(name=invalid_name):
+                with tempfile.TemporaryDirectory(prefix="test_invalid_") as 
tmpdir:
+                    
self.spark.range(10).write.mode("overwrite").parquet(tmpdir)
+                    with self.assertRaises(PySparkValueError) as context:
+                        self.spark.readStream.format("parquet").schema("id 
LONG").name(
+                            invalid_name
+                        ).load(tmpdir)
+
+                    # The error message should contain information about 
invalid name
+                    self.assertIn("source", str(context.exception).lower())
+
+    def test_invalid_name_empty_string(self):
+        """Test that empty string is rejected."""
+        with self.assertRaises(PySparkValueError):
+            self.spark.readStream.format("rate").name("").load()
+
+    def test_invalid_name_none(self):

Review Comment:
   Maybe combine `None` and `123` together? They raise a different error than 
those invalid names, so it makes sense to put them in a separate test case.



##########
python/pyspark/sql/streaming/readwriter.py:
##########
@@ -241,6 +242,58 @@ def options(self, **options: "OptionalPrimitiveType") -> 
"DataStreamReader":
             self._jreader = self._jreader.option(k, to_str(options[k]))
         return self
 
+    def name(self, source_name: str) -> "DataStreamReader":
+        """Specifies a name for the streaming source.
+
+        This name is used to identify the source in checkpoint metadata and 
enables
+        stable checkpoint locations for source evolution.
+
+        .. versionadded:: 4.2.0
+
+        Parameters
+        ----------
+        source_name : str
+            the name to assign to this streaming source. Must contain only 
ASCII letters,
+            digits, and underscores.
+
+        Returns
+        -------
+        :class:`DataStreamReader`
+
+        Notes
+        -----
+        This API is experimental.
+
+        Examples
+        --------
+        >>> spark.readStream.format("rate").name("my_source")  # doctest: +SKIP
+        <...streaming.readwriter.DataStreamReader object ...>
+        """
+        if not isinstance(source_name, str):
+            raise PySparkTypeError(
+                errorClass="NOT_STR",
+                messageParameters={
+                    "arg_name": "source_name",
+                    "arg_type": type(source_name).__name__,
+                },
+            )
+
+        if not source_name or len(source_name.strip()) == 0:

Review Comment:
   Do we consider `"  "` an empty string or an invalid name? I think we can fall 
into the next category and just claim spaces can't be used in a source name. Also, 
`str()` in `str(source_name)` is not needed because we know it's a string now :)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to