This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6cb0cd5bcb5 [SPARK-45713][PYTHON] Support registering Python data sources
6cb0cd5bcb5 is described below

commit 6cb0cd5bcb502b07a6388ce4d831e152e90e534d
Author: allisonwang-db <[email protected]>
AuthorDate: Tue Oct 31 13:08:06 2023 +0800

    [SPARK-45713][PYTHON] Support registering Python data sources
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for registering Python data sources.
    Users can register a Python data source using the class:
    
    ```python
    class MyDataSource(DataSource):
        ...
    
    spark.dataSource.register(MyDataSource)
    ```
    
    This will allow users to refer to the data source by its name (to be supported in SPARK-45639):
    ```python
    spark.read.format("MyDataSource").load()
    ```
    
    Registered data sources are stored in `sharedState` and can be accessed by all Spark sessions.
    
    ### Why are the changes needed?
    
    To support Python data source.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR adds a new API in SparkSession to support registering data sources.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43566 from allisonwang-db/spark-45713-ds-register.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../src/main/resources/error/error-classes.json    |   8 +-
 .../CheckConnectJvmClientCompatibility.scala       |   2 +
 python/pyspark/sql/datasource.py                   |  75 ++++++--
 python/pyspark/sql/session.py                      |  15 ++
 python/pyspark/sql/tests/test_python_datasource.py |  16 +-
 python/pyspark/sql/worker/create_data_source.py    | 190 +++++++++++++++++++++
 .../spark/sql/errors/QueryCompilationErrors.scala  |  16 +-
 .../apache/spark/sql/DataSourceRegistration.scala  |  48 ++++++
 .../scala/org/apache/spark/sql/SparkSession.scala  |   5 +
 .../execution/datasources/DataSourceManager.scala  |  54 ++++++
 .../python/UserDefinedPythonDataSource.scala       | 129 ++++++++++++--
 .../apache/spark/sql/internal/SharedState.scala    |  12 ++
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |  13 +-
 .../execution/python/PythonDataSourceSuite.scala   | 151 +++++++++++++---
 14 files changed, 672 insertions(+), 62 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index f9cc0a86521..278011b8cc8 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -812,6 +812,12 @@
     ],
     "sqlState" : "42K01"
   },
+  "DATA_SOURCE_ALREADY_EXISTS" : {
+    "message" : [
+      "Data source '<provider>' already exists in the registry. Please use a different name for the new data source."
+    ],
+    "sqlState" : "42710"
+  },
   "DATA_SOURCE_NOT_FOUND" : {
     "message" : [
       "Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
@@ -2774,7 +2780,7 @@
   },
   "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON" : {
     "message" : [
-      "Failed to plan Python data source <type> in Python: <msg>"
+      "Failed to <action> Python data source <type> in Python: <msg>"
     ],
     "sqlState" : "38000"
   },
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7ddb339b12d..fb4f80998fc 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.SparkSessionExtensionsProvider"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"),
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataSourceRegistration"),
 
       // DataFrame Reader & Writer
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"), // rdd
@@ -227,6 +228,7 @@ object CheckConnectJvmClientCompatibility {
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.dataSource"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
       ProblemFilters.exclude[Problem](
         "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 29dc194b7f8..5cda6596b3f 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,16 +15,18 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, Tuple, Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
 
+from pyspark import since
 from pyspark.sql import Row
 from pyspark.sql.types import StructType
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
+    from pyspark.sql.session import SparkSession
 
 
-__all__ = ["DataSource", "DataSourceReader"]
+__all__ = ["DataSource", "DataSourceReader", "DataSourceRegistration"]
 
 
 class DataSource(ABC):
@@ -41,23 +43,35 @@ class DataSource(ABC):
     ``spark.read.format(...).load()`` and save data using ``df.write.format(...).save()``.
     """
 
-    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]):
+    @final
+    def __init__(
+        self,
+        paths: List[str],
+        userSpecifiedSchema: Optional[StructType],
+        options: Dict[str, "OptionalPrimitiveType"],
+    ) -> None:
         """
-        Initializes the data source with user-provided options.
+        Initializes the data source with user-provided information.
 
         Parameters
         ----------
+        paths : list
+            A list of paths to the data source.
+        userSpecifiedSchema : StructType, optional
+            The user-specified schema of the data source.
         options : dict
             A dictionary representing the options for this data source.
 
         Notes
         -----
-        This method should not contain any non-serializable objects.
+        This method should not be overridden.
         """
+        self.paths = paths
+        self.userSpecifiedSchema = userSpecifiedSchema
         self.options = options
 
-    @property
-    def name(self) -> str:
+    @classmethod
+    def name(cls) -> str:
         """
         Returns a string represents the format name of this data source.
 
@@ -66,20 +80,21 @@ class DataSource(ABC):
 
         Examples
         --------
-        >>> def name(self):
+        >>> def name(cls):
         ...     return "my_data_source"
         """
-        return self.__class__.__name__
+        return cls.__name__
 
     def schema(self) -> Union[StructType, str]:
         """
         Returns the schema of the data source.
 
-        It can reference the ``options`` field to infer the data source's schema when
-        users do not explicitly specify it. This method is invoked once when calling
-        ``spark.read.format(...).load()`` to get the schema for a data source read
-        operation. If this method is not implemented, and a user does not provide a
-        schema when reading the data source, an exception will be thrown.
+        It can refer any field initialized in the ``__init__`` method to infer the
+        data source's schema when users do not explicitly specify it. This method is
+        invoked once when calling ``spark.read.format(...).load()`` to get the schema
+        for a data source read operation. If this method is not implemented, and a
+        user does not provide a schema when reading the data source, an exception will
+        be thrown.
 
         Returns
         -------
@@ -212,3 +227,35 @@ class DataSourceReader(ABC):
         ...     yield Row(partition=partition, value=1)
         """
         ...
+
+
+@since(4.0)
+class DataSourceRegistration:
+    """
+    Wrapper for data source registration. This instance can be accessed by
+    :attr:`spark.dataSource`.
+    """
+
+    def __init__(self, sparkSession: "SparkSession"):
+        self.sparkSession = sparkSession
+
+    def register(
+        self,
+        dataSource: Type["DataSource"],
+    ) -> None:
+        """Register a Python user-defined data source.
+
+        Parameters
+        ----------
+        dataSource : type
+            The data source class to be registered. It should be a subclass of DataSource.
+        """
+        from pyspark.sql.udf import _wrap_function
+
+        name = dataSource.name()
+        sc = self.sparkSession.sparkContext
+        # Serialize the data source class.
+        wrapped = _wrap_function(sc, dataSource)
+        assert sc._jvm is not None
+        ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
+        self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 5c07ca607b7..1ffb602ba86 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
     from pyspark.sql.streaming import StreamingQueryManager
     from pyspark.sql.udf import UDFRegistration
     from pyspark.sql.udtf import UDTFRegistration
+    from pyspark.sql.datasource import DataSourceRegistration
 
     # Running MyPy type checks will always require pandas and
     # other dependencies so importing here is fine.
@@ -863,6 +864,20 @@ class SparkSession(SparkConversionMixin):
 
         return UDTFRegistration(self)
 
+    @property
+    def dataSource(self) -> "DataSourceRegistration":
+        """Returns a :class:`DataSourceRegistration` for data source registration.
+
+        .. versionadded:: 4.0.0
+
+        Returns
+        -------
+        :class:`DataSourceRegistration`
+        """
+        from pyspark.sql.datasource import DataSourceRegistration
+
+        return DataSourceRegistration(self)
+
     def range(
         self,
         start: int,
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 3ac103d2c7c..6584312dda4 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -26,9 +26,9 @@ class BasePythonDataSourceTestsMixin:
             ...
 
         options = dict(a=1, b=2)
-        ds = MyDataSource(options)
+        ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
         self.assertEqual(ds.options, options)
-        self.assertEqual(ds.name, "MyDataSource")
+        self.assertEqual(ds.name(), "MyDataSource")
         with self.assertRaises(NotImplementedError):
             ds.schema()
         with self.assertRaises(NotImplementedError):
@@ -43,6 +43,18 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_register_data_source(self):
+        class MyDataSource(DataSource):
+            ...
+
+        self.spark.dataSource.register(MyDataSource)
+
+        self.assertTrue(
+            self.spark._jsparkSession.sharedState()
+            .dataSourceManager()
+            .dataSourceExists("MyDataSource")
+        )
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
new file mode 100644
index 00000000000..ea56d2cc752
--- /dev/null
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+from typing import IO, List
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.datasource import DataSource
+from pyspark.sql.types import _parse_datatype_json_string, StructType
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    """
+    Main method for creating a Python data source instance.
+
+    This process is invoked from the `UserDefinedPythonDataSourceRunner.runInPython` method
+    in JVM. This process is responsible for creating a `DataSource` object and send the
+    information needed back to the JVM.
+
+    The JVM sends the following information to this process:
+    - a `DataSource` class representing the data source to be created.
+    - a provider name in string.
+    - a list of paths in string.
+    - an optional user-specified schema in json string.
+    - a dictionary of options in string.
+
+    This process then creates a `DataSource` instance using the above information and
+    sends the pickled instance as well as the schema back to the JVM.
+    """
+    try:
+        check_python_version(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        setup_spark_files(infile)
+        setup_broadcasts(infile)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source class.
+        data_source_cls = read_command(pickleSer, infile)
+        if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a subclass of DataSource",
+                    "actual": f"'{type(data_source_cls).__name__}'",
+                },
+            )
+
+        # Receive the provider name.
+        provider = utf8_deserializer.loads(infile)
+        if provider.lower() != data_source_cls.name().lower():
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": f"provider with name {data_source_cls.name()}",
+                    "actual": f"'{provider}'",
+                },
+            )
+
+        # Receive the paths.
+        num_paths = read_int(infile)
+        paths: List[str] = []
+        for _ in range(num_paths):
+            paths.append(utf8_deserializer.loads(infile))
+
+        # Receive the user-specified schema
+        user_specified_schema = None
+        if read_bool(infile):
+            user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+            if not isinstance(user_specified_schema, StructType):
+                raise PySparkAssertionError(
+                    error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                    message_parameters={
+                        "expected": "the user-defined schema to be a 'StructType'",
+                        "actual": f"'{type(user_specified_schema).__name__}'",
+                    },
+                )
+
+        # Receive the options.
+        options = dict()
+        num_options = read_int(infile)
+        for _ in range(num_options):
+            key = utf8_deserializer.loads(infile)
+            value = utf8_deserializer.loads(infile)
+            options[key] = value
+
+        # Instantiate a data source.
+        try:
+            data_source = data_source_cls(
+                paths=paths,
+                userSpecifiedSchema=user_specified_schema,  # type: ignore
+                options=options,
+            )
+        except Exception as e:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+                message_parameters={"type": "instance", "error": str(e)},
+            )
+
+        # Get the schema of the data source.
+        # If user_specified_schema is not None, use user_specified_schema.
+        # Otherwise, use the schema of the data source.
+        # Throw exception if the data source does not implement schema().
+        is_ddl_string = False
+        if user_specified_schema is None:
+            try:
+                schema = data_source.schema()
+                if isinstance(schema, str):
+                    # Here we cannot use _parse_datatype_string to parse the DDL string schema.
+                    # as it requires an active Spark session.
+                    is_ddl_string = True
+            except NotImplementedError:
+                raise PySparkRuntimeError(
+                    error_class="PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED",
+                    message_parameters={"type": "instance", "method": "schema"},
+                )
+        else:
+            schema = user_specified_schema  # type: ignore
+
+        assert schema is not None
+
+        # Return the pickled data source instance.
+        pickleSer._write_with_length(data_source, outfile)
+
+        # Return the schema of the data source.
+        write_int(int(is_ddl_string), outfile)
+        if is_ddl_string:
+            write_with_length(schema.encode("utf-8"), outfile)  # type: ignore
+        else:
+            write_with_length(schema.json().encode("utf-8"), outfile)  # type: ignore
+
+    except BaseException as e:
+        handle_worker_exception(e, outfile)
+        sys.exit(-1)
+
+    send_accumulator_updates(outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    main(sock_file, sock_file)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 5fcd71d8bf9..4b28eadfec6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1174,10 +1174,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
   }
 
   def schemaIsNotStructTypeError(exp: Expression, dataType: DataType): Throwable = {
+    schemaIsNotStructTypeError(toSQLExpr(exp), dataType)
+  }
+
+  def schemaIsNotStructTypeError(inputSchema: String, dataType: DataType): Throwable = {
     new AnalysisException(
       errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
       messageParameters = Map(
-        "inputSchema" -> toSQLExpr(exp),
+        "inputSchema" -> inputSchema,
         "dataType" -> toSQLType(dataType)
       ))
   }
@@ -1981,10 +1985,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
     )
   }
 
-  def failToPlanDataSourceError(tpe: String, msg: String): Throwable = {
+  def failToPlanDataSourceError(action: String, tpe: String, msg: String): Throwable = {
     new AnalysisException(
       errorClass = "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON",
-      messageParameters = Map("type" -> tpe, "msg" -> msg)
+      messageParameters = Map("action" -> action, "type" -> tpe, "msg" -> msg)
     )
   }
 
@@ -3801,4 +3805,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
         "functionName" -> functionName,
         "reason" -> reason))
   }
+
+  def dataSourceAlreadyExists(name: String): Throwable = {
+    new AnalysisException(
+      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
+      messageParameters = Map("provider" -> name))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
new file mode 100644
index 00000000000..15d26418984
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.datasources.DataSourceManager
+import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+
+/**
+ * Functions for registering user-defined data sources.
+ * Use `SparkSession.dataSource` to access this.
+ */
+@Evolving
+private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager)
+  extends Logging {
+
+  protected[sql] def registerPython(
+      name: String,
+      dataSource: UserDefinedPythonDataSource): Unit = {
+    log.debug(
+      s"""
+         | Registering new Python data source:
+         | name: $name
+         | command: ${dataSource.dataSourceCls.command}
+         | envVars: ${dataSource.dataSourceCls.envVars}
+         | pythonIncludes: ${dataSource.dataSourceCls.pythonIncludes}
+         | pythonExec: ${dataSource.dataSourceCls.pythonExec}
+      """.stripMargin)
+
+    dataSourceManager.registerDataSource(name, dataSource.builder)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 59aa17baa7f..aec40c845fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -228,6 +228,11 @@ class SparkSession private(
 
   def udtf: UDTFRegistration = sessionState.udtfRegistration
 
+  /**
+   * A collection of methods for registering user-defined data sources.
+   */
+  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
    * `StreamingQuery`s active on `this`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
new file mode 100644
index 00000000000..283ca2ac62e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.Locale
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class DataSourceManager {
+
+  private type DataSourceBuilder = (
+    SparkSession,  // Spark session
+    String,  // provider name
+    Seq[String],  // paths
+    Option[StructType],  // user specified schema
+    CaseInsensitiveStringMap  // options
+  ) => LogicalPlan
+
+  private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
+
+  private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+
+  def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
+    val normalizedName = normalize(name)
+    // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
+    // putIfAbsent makes the existence check and the insert atomic across sessions.
+    if (dataSourceBuilders.putIfAbsent(normalizedName, builder) != null) {
+      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    }
+  }
+
+  def dataSourceExists(name: String): Boolean =
+    dataSourceBuilders.containsKey(normalize(name))
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index faae4101fa0..dbff8eefcd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,25 +20,133 @@ package org.apache.spark.sql.execution.python
 import java.io.{DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.PythonDataSource
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * A user-defined Python data source. This is used by the Python API.
+ *
+ * @param dataSourceCls The Python data source class.
+ */
+case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
+
+  def builder(
+      sparkSession: SparkSession,
+      provider: String,
+      paths: Seq[String],
+      userSpecifiedSchema: Option[StructType],
+      options: CaseInsensitiveStringMap): LogicalPlan = {
+
+    val runner = new UserDefinedPythonDataSourceRunner(
+      dataSourceCls, provider, paths, userSpecifiedSchema, options)
+
+    val result = runner.runInPython()
+    val pickledDataSourceInstance = result.dataSource
+
+    val dataSource = SimplePythonFunction(
+      command = pickledDataSourceInstance,
+      envVars = dataSourceCls.envVars,
+      pythonIncludes = dataSourceCls.pythonIncludes,
+      pythonExec = dataSourceCls.pythonExec,
+      pythonVer = dataSourceCls.pythonVer,
+      broadcastVars = dataSourceCls.broadcastVars,
+      accumulator = dataSourceCls.accumulator)
+    val schema = result.schema
+
+    PythonDataSource(dataSource, schema, output = toAttributes(schema))
+  }
+
+  def apply(
+      sparkSession: SparkSession,
+      provider: String,
+      paths: Seq[String] = Seq.empty,
+      userSpecifiedSchema: Option[StructType] = None,
+      options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = {
+    val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
+    Dataset.ofRows(sparkSession, plan)
+  }
+}
+
+/**
+ * Used to store the result of creating a Python data source in the Python process.
  */
-case class UserDefinedPythonDataSource(
-    dataSource: PythonFunction,
-    schema: StructType) {
-  def apply(session: SparkSession): DataFrame = {
-    val source = PythonDataSource(dataSource, schema, output = toAttributes(schema))
-    Dataset.ofRows(session, source)
+case class PythonDataSourceCreationResult(
+    dataSource: Array[Byte],
+    schema: StructType)
+
+/**
+ * A runner used to create a Python data source in a Python process and return the result.
+ */
+class UserDefinedPythonDataSourceRunner(
+    dataSourceCls: PythonFunction,
+    provider: String,
+    paths: Seq[String],
+    userSpecifiedSchema: Option[StructType],
+    options: CaseInsensitiveStringMap)
+  extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
+
+  override val workerModule = "pyspark.sql.worker.create_data_source"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send python data source
+    PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
+
+    // Send the provider name
+    PythonWorkerUtils.writeUTF(provider, dataOut)
+
+    // Send the paths
+    dataOut.writeInt(paths.length)
+    paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
+
+    // Send the user-specified schema, if provided
+    dataOut.writeBoolean(userSpecifiedSchema.isDefined)
+    userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))
+
+    // Send the options
+    dataOut.writeInt(options.size)
+    options.entrySet.asScala.foreach { e =>
+      PythonWorkerUtils.writeUTF(e.getKey, dataOut)
+      PythonWorkerUtils.writeUTF(e.getValue, dataOut)
+    }
+  }
+
+  override protected def receiveFromPython(
+      dataIn: DataInputStream): PythonDataSourceCreationResult = {
+    // Receive the pickled data source or an exception raised in Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.failToPlanDataSourceError(
+        action = "create", tpe = "instance", msg = msg)
+    }
+
+    // Receive the pickled data source.
+    val pickledDataSourceInstance: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+    // Receive the schema.
+    val isDDLString = dataIn.readInt()
+    val schemaStr = PythonWorkerUtils.readUTF(dataIn)
+    val schema = if (isDDLString == 1) {
+      DataType.fromDDL(schemaStr)
+    } else {
+      DataType.fromJson(schemaStr)
+    }
+    if (!schema.isInstanceOf[StructType]) {
+      throw QueryCompilationErrors.schemaIsNotStructTypeError(schemaStr, schema)
+    }
+
+    PythonDataSourceCreationResult(
+      dataSource = pickledDataSourceInstance,
+      schema = schema.asInstanceOf[StructType])
   }
 }
 
@@ -66,7 +174,8 @@ class UserDefinedPythonDataSourceReadRunner(
     val length = dataIn.readInt()
     if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
       val msg = PythonWorkerUtils.readUTF(dataIn)
-      throw QueryCompilationErrors.failToPlanDataSourceError("read", msg)
+      throw QueryCompilationErrors.failToPlanDataSourceError(
+        action = "plan", tpe = "read", msg = msg)
     }
 
     // Receive the pickled 'read' function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 164710cdd88..8adc32fcf62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,9 +30,11 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -105,6 +107,16 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
+  /**
+   * A data source manager shared by all sessions.
+   */
+  lazy val dataSourceManager = new DataSourceManager()
+
+  /**
+   * A collection of method used for registering user-defined data sources.
+   */
+  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
+
   /**
    * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 309d25ede4f..eab3b7d81b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -232,8 +232,8 @@ object IntegratedUDFTestUtils extends SQLHelper {
             "from pyspark.serializers import CloudPickleSerializer; " +
               s"f = open('$path', 'wb');" +
               s"exec(open('$codePath', 'r').read());" +
-              s"ds = $dataSourceName(options=dict());" +
-              "f.write(CloudPickleSerializer().dumps(ds))"),
+              s"dataSourceCls = $dataSourceName;" +
+              "f.write(CloudPickleSerializer().dumps(dataSourceCls))"),
           None,
           "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
         binaryPythonDataSource = Files.readAllBytes(path.toPath)
@@ -425,19 +425,16 @@ object IntegratedUDFTestUtils extends SQLHelper {
 
   def createUserDefinedPythonDataSource(
       name: String,
-      pythonScript: String,
-      schema: StructType): UserDefinedPythonDataSource = {
+      pythonScript: String): UserDefinedPythonDataSource = {
     UserDefinedPythonDataSource(
-      dataSource = SimplePythonFunction(
+      dataSourceCls = SimplePythonFunction(
         command = createPythonDataSource(name, pythonScript),
         envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
         pythonIncludes = List.empty[String].asJava,
         pythonExec = pythonExec,
         pythonVer = pythonVer,
         broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
-        accumulator = null),
-      schema = schema
-    )
+        accumulator = null))
   }
 
   def createUserDefinedPythonTableFunction(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index b038bcd5529..6c749c2c9b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -25,26 +25,34 @@ import org.apache.spark.sql.types.StructType
 class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   import IntegratedUDFTestUtils._
 
+  private def dataSourceName = "SimpleDataSource"
+  private def simpleDataSourceReaderScript: String =
+    """
+      |class SimpleDataSourceReader(DataSourceReader):
+      |    def partitions(self):
+      |        return range(0, 2)
+      |    def read(self, partition):
+      |        yield (0, partition)
+      |        yield (1, partition)
+      |        yield (2, partition)
+      |""".stripMargin
+
   test("simple data source") {
+    assume(shouldTestPythonUDFs)
     val dataSourceScript =
-      """
+      s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
-        |class MyDataSourceReader(DataSourceReader):
-        |    def partitions(self):
-        |        return range(0, 2)
-        |    def read(self, partition):
-        |        yield (0, partition)
-        |        yield (1, partition)
-        |        yield (2, partition)
+        |$simpleDataSourceReaderScript
         |
-        |class MyDataSource(DataSource):
+        |class $dataSourceName(DataSource):
         |    def reader(self, schema):
-        |        return MyDataSourceReader()
+        |        return SimpleDataSourceReader()
         |""".stripMargin
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
-      name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
-    val df = dataSource(spark)
+      name = dataSourceName, pythonScript = dataSourceScript)
+    val df = dataSource.apply(
+      spark, provider = dataSourceName, userSpecifiedSchema = Some(schema))
     assert(df.rdd.getNumPartitions == 2)
     val plan = df.queryExecution.optimizedPlan
     plan match {
@@ -55,36 +63,130 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
   }
 
+  test("simple data source with string schema") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |$simpleDataSourceReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT, partition INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    val df = dataSource(spark, provider = dataSourceName)
+    checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+  }
+
+  test("simple data source with StructType schema") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |from pyspark.sql.types import IntegerType, StructType, StructField
+         |$simpleDataSourceReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return StructType([
+         |            StructField("id", IntegerType()),
+         |            StructField("partition", IntegerType())
+         |        ])
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    val df = dataSource(spark, provider = dataSourceName)
+    checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+  }
+
+  test("data source with invalid schema") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |$simpleDataSourceReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    checkError(
+      exception = intercept[AnalysisException](dataSource(spark, provider = dataSourceName)),
+      errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
+      parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
+  }
+
+  test("register data source") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |$simpleDataSourceReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT, partition INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    // Check error when registering a data source with the same name.
+    val err = intercept[AnalysisException] {
+      spark.dataSource.registerPython(dataSourceName, dataSource)
+    }
+    checkError(
+      exception = err,
+      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
+      parameters = Map("provider" -> dataSourceName))
+  }
+
   test("reader not implemented") {
+    assume(shouldTestPythonUDFs)
     val dataSourceScript =
-      """
+       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
-        |class MyDataSource(DataSource):
+        |class $dataSourceName(DataSource):
         |    pass
         |""".stripMargin
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
-      name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+      name = dataSourceName, pythonScript = dataSourceScript)
     val err = intercept[AnalysisException] {
-      dataSource(spark).collect()
+      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED"))
   }
 
   test("error creating reader") {
+    assume(shouldTestPythonUDFs)
     val dataSourceScript =
-      """
+      s"""
         |from pyspark.sql.datasource import DataSource
-        |class MyDataSource(DataSource):
+        |class $dataSourceName(DataSource):
         |    def reader(self, schema):
         |        raise Exception("error creating reader")
         |""".stripMargin
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
-      name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+      name = dataSourceName, pythonScript = dataSourceScript)
     val err = intercept[AnalysisException] {
-      dataSource(spark).collect()
+      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
@@ -92,17 +194,18 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source assertion error") {
+    assume(shouldTestPythonUDFs)
     val dataSourceScript =
-      """
-        |class MyDataSource:
+      s"""
+        |class $dataSourceName:
         |   def __init__(self, options):
         |       ...
         |""".stripMargin
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
-      name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+      name = dataSourceName, pythonScript = dataSourceScript)
     val err = intercept[AnalysisException] {
-      dataSource(spark).collect()
+      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_TYPE_MISMATCH"))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to