This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a98c885da09d [SPARK-46568][PYTHON] Make Python data source options a case-insensitive dictionary
a98c885da09d is described below

commit a98c885da09d45a19568f5d853086f747e0ecd95
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Thu Jan 4 22:28:49 2024 -0800

    [SPARK-46568][PYTHON] Make Python data source options a case-insensitive dictionary
    
    ### What changes were proposed in this pull request?
    
    This PR updates the `options` field to use a case-insensitive dictionary to keep the behavior consistent with the Scala side (which uses `CaseInsensitiveStringMap`). Currently, `options` are stored in a normal Python dictionary, which can be confusing to users. For instance:
    ```python
    class MyDataSource(DataSource):
        def __init__(self, options):
            self.api_key = options.get("API_KEY") # <- This is None
    
    spark.read.format(..).option("API_KEY", my_key).load(...)
    ```
    Here, `options` will not contain "API_KEY" because all option keys are lowercased on the Scala side.
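    
    With this change, lookups on `options` become case-insensitive on the Python side as well. Below is a minimal sketch of the resulting behavior, reusing the hypothetical `MyDataSource` and `API_KEY` option from the example above:
    ```python
    class MyDataSource(DataSource):
        def __init__(self, options):
            # Keys are normalized internally, so any casing retrieves the value.
            self.api_key = options.get("API_KEY")  # <- now returns my_key
    
    spark.read.format(..).option("API_KEY", my_key).load(...)
    ```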
    
    ### Why are the changes needed?
    
    To improve usability.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44564 from allisonwang-db/spark-46568-ds-options.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/datasource.py                   | 51 +++++++++++++++++++---
 python/pyspark/sql/tests/test_python_datasource.py | 21 +++++++++
 python/pyspark/sql/worker/create_data_source.py    |  6 +--
 .../pyspark/sql/worker/write_into_data_source.py   |  6 +--
 .../execution/python/PythonDataSourceSuite.scala   | 43 ++++++++++++++++++
 5 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index e20d44039a69..bdedbac3544e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,18 +15,25 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
+from collections import UserDict
+from typing import Any, Dict, Iterator, List, Sequence, Tuple, Type, Union, TYPE_CHECKING
 
 from pyspark.sql import Row
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError
 
 if TYPE_CHECKING:
-    from pyspark.sql._typing import OptionalPrimitiveType
     from pyspark.sql.session import SparkSession
 
 
-__all__ = ["DataSource", "DataSourceReader", "DataSourceWriter", "DataSourceRegistration"]
+__all__ = [
+    "DataSource",
+    "DataSourceReader",
+    "DataSourceWriter",
+    "DataSourceRegistration",
+    "InputPartition",
+    "WriterCommitMessage",
+]
 
 
 class DataSource(ABC):
@@ -45,15 +52,14 @@ class DataSource(ABC):
     .. versionadded: 4.0.0
     """
 
-    @final
-    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
+    def __init__(self, options: Dict[str, str]) -> None:
         """
         Initializes the data source with user-provided options.
 
         Parameters
         ----------
         options : dict
-            A dictionary representing the options for this data source.
+            A case-insensitive dictionary representing the options for this data source.
 
         Notes
         -----
@@ -403,3 +409,36 @@ class DataSourceRegistration:
         assert sc._jvm is not None
         ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
         self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
+
+
+class CaseInsensitiveDict(UserDict):
+    """
+    A case-insensitive map of string keys to values.
+
+    This is used by Python data source options to ensure consistent case insensitivity.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+        self.update(*args, **kwargs)
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        super().__setitem__(key.lower(), value)
+
+    def __getitem__(self, key: str) -> Any:
+        return super().__getitem__(key.lower())
+
+    def __delitem__(self, key: str) -> None:
+        super().__delitem__(key.lower())
+
+    def __contains__(self, key: object) -> bool:
+        if isinstance(key, str):
+            return super().__contains__(key.lower())
+        return False
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        for k, v in dict(*args, **kwargs).items():
+            self[k] = v
+
+    def copy(self) -> "CaseInsensitiveDict":
+        return type(self)(self)
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index ce629b2718e2..79414cb7ed69 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -26,6 +26,7 @@ from pyspark.sql.datasource import (
     InputPartition,
     DataSourceWriter,
     WriterCommitMessage,
+    CaseInsensitiveDict,
 )
 from pyspark.sql.types import Row, StructType
 from pyspark.testing import assertDataFrameEqual
@@ -346,6 +347,26 @@ class BasePythonDataSourceTestsMixin:
                 text = file.read()
             assert text == "failed"
 
+    def test_case_insensitive_dict(self):
+        d = CaseInsensitiveDict({"foo": 1, "Bar": 2})
+        self.assertEqual(d["foo"], d["FOO"])
+        self.assertEqual(d["bar"], d["BAR"])
+        self.assertTrue("baR" in d)
+        d["BAR"] = 3
+        self.assertEqual(d["BAR"], 3)
+        # Test update
+        d.update({"BaZ": 3})
+        self.assertEqual(d["BAZ"], 3)
+        d.update({"FOO": 4})
+        self.assertEqual(d["foo"], 4)
+        # Test delete
+        del d["FoO"]
+        self.assertFalse("FOO" in d)
+        # Test copy
+        d2 = d.copy()
+        self.assertEqual(d2["BaR"], 3)
+        self.assertEqual(d2["baz"], 3)
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 1ba4dc9e8a3c..a377911c6e9b 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -29,7 +29,7 @@ from pyspark.serializers import (
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.datasource import DataSource
+from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
 from pyspark.sql.types import _parse_datatype_json_string, StructType
 from pyspark.util import handle_worker_exception
 from pyspark.worker_util import (
@@ -120,7 +120,7 @@ def main(infile: IO, outfile: IO) -> None:
                 )
 
         # Receive the options.
-        options = dict()
+        options = CaseInsensitiveDict()
         num_options = read_int(infile)
         for _ in range(num_options):
             key = utf8_deserializer.loads(infile)
@@ -129,7 +129,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate a data source.
         try:
-            data_source = data_source_cls(options=options)
+            data_source = data_source_cls(options=options)  # type: ignore
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 36b3c23b3379..0ba6fc6eb17f 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -30,7 +30,7 @@ from pyspark.serializers import (
     SpecialLengths,
 )
 from pyspark.sql import Row
-from pyspark.sql.datasource import DataSource, WriterCommitMessage
+from pyspark.sql.datasource import DataSource, WriterCommitMessage, CaseInsensitiveDict
 from pyspark.sql.types import (
     _parse_datatype_json_string,
     StructType,
@@ -142,7 +142,7 @@ def main(infile: IO, outfile: IO) -> None:
         return_col_name = return_type[0].name
 
         # Receive the options.
-        options = dict()
+        options = CaseInsensitiveDict()
         num_options = read_int(infile)
         for _ in range(num_options):
             key = utf8_deserializer.loads(infile)
@@ -153,7 +153,7 @@ def main(infile: IO, outfile: IO) -> None:
         overwrite = read_bool(infile)
 
         # Instantiate a data source.
-        data_source = data_source_cls(options=options)
+        data_source = data_source_cls(options=options)  # type: ignore
 
         # Instantiate the data source writer.
         writer = data_source.writer(schema, overwrite)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index a4614c2b4bad..1cd8fb6819cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -790,4 +790,47 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46568: case insensitive options") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource, DataSourceReader, DataSourceWriter, WriterCommitMessage)
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def __init__(self, options):
+         |        self.options = options
+         |
+         |    def read(self, partition):
+         |        foo = self.options.get("Foo")
+         |        bar = self.options.get("BAR")
+         |        baz = "BaZ" in self.options
+         |        yield (foo, bar, baz)
+         |
+         |class SimpleDataSourceWriter(DataSourceWriter):
+         |    def __init__(self, options):
+         |        self.options = options
+         |
+         |    def write(self, row):
+         |        if "FOO" not in self.options or "BAR" not in self.options:
+         |            raise Exception("FOO or BAR not found")
+         |        return WriterCommitMessage()
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "a string, b string, c string"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader(self.options)
+         |
+         |    def writer(self, schema, overwrite):
+         |        return SimpleDataSourceWriter(self.options)
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.option("foo", 1).option("bar", 2).option("BAZ", 3)
+      .format(dataSourceName).load()
+    checkAnswer(df, Row("1", "2", "true"))
+    df.write.option("foo", 1).option("bar", 2).format(dataSourceName).mode("append").save()
+  }
 }

