This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6cb0cd5bcb5 [SPARK-45713][PYTHON] Support registering Python data sources
6cb0cd5bcb5 is described below
commit 6cb0cd5bcb502b07a6388ce4d831e152e90e534d
Author: allisonwang-db <[email protected]>
AuthorDate: Tue Oct 31 13:08:06 2023 +0800
[SPARK-45713][PYTHON] Support registering Python data sources
### What changes were proposed in this pull request?
This PR adds support for registering Python data sources.
Users can register a Python data source using its class:
```python
class MyDataSource(DataSource):
...
spark.dataSource.register(MyDataSource)
```
This will allow users to use the data source by its name (to be supported in SPARK-45639):
```python
spark.read.format("MyDataSource").load()
```
Registered data sources are stored in the `sharedState` and can be accessed by all Spark sessions.
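For illustration, here is a fuller sketch of a data source that can be registered this way, adapted from the test scripts included in this patch (the `SimpleDataSource` name and the `spark` session variable are assumptions for this example, not part of the patch itself):
```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class SimpleDataSourceReader(DataSourceReader):
    def partitions(self):
        # Two partitions; read() is invoked once per partition.
        return range(0, 2)

    def read(self, partition):
        # Yield tuples matching the schema declared below.
        yield (0, partition)
        yield (1, partition)


class SimpleDataSource(DataSource):
    def schema(self):
        # Either a DDL string or a StructType.
        return "id INT, partition INT"

    def reader(self, schema):
        return SimpleDataSourceReader()


# `spark` is an active SparkSession; this stores the data source builder in the
# shared state under the name "SimpleDataSource".
spark.dataSource.register(SimpleDataSource)
```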
### Why are the changes needed?
To support Python data sources.
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds a new API to the Spark session for registering data sources.
### How was this patch tested?
Unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43566 from allisonwang-db/spark-45713-ds-register.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-classes.json | 8 +-
.../CheckConnectJvmClientCompatibility.scala | 2 +
python/pyspark/sql/datasource.py | 75 ++++++--
python/pyspark/sql/session.py | 15 ++
python/pyspark/sql/tests/test_python_datasource.py | 16 +-
python/pyspark/sql/worker/create_data_source.py | 190 +++++++++++++++++++++
.../spark/sql/errors/QueryCompilationErrors.scala | 16 +-
.../apache/spark/sql/DataSourceRegistration.scala | 48 ++++++
.../scala/org/apache/spark/sql/SparkSession.scala | 5 +
.../execution/datasources/DataSourceManager.scala | 54 ++++++
.../python/UserDefinedPythonDataSource.scala | 129 ++++++++++++--
.../apache/spark/sql/internal/SharedState.scala | 12 ++
.../apache/spark/sql/IntegratedUDFTestUtils.scala | 13 +-
.../execution/python/PythonDataSourceSuite.scala | 151 +++++++++++++---
14 files changed, 672 insertions(+), 62 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index f9cc0a86521..278011b8cc8 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -812,6 +812,12 @@
],
"sqlState" : "42K01"
},
+ "DATA_SOURCE_ALREADY_EXISTS" : {
+ "message" : [
+ "Data source '<provider>' already exists in the registry. Please use a
different name for the new data source."
+ ],
+ "sqlState" : "42710"
+ },
"DATA_SOURCE_NOT_FOUND" : {
"message" : [
"Failed to find the data source: <provider>. Please find packages at
`https://spark.apache.org/third-party-projects.html`."
@@ -2774,7 +2780,7 @@
},
"PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON" : {
"message" : [
- "Failed to plan Python data source <type> in Python: <msg>"
+ "Failed to <action> Python data source <type> in Python: <msg>"
],
"sqlState" : "38000"
},
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 7ddb339b12d..fb4f80998fc 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -177,6 +177,7 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.SparkSessionExtensionsProvider"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDTFRegistration"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UDFRegistration$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataSourceRegistration"),
// DataFrame Reader & Writer
ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"),
// rdd
@@ -227,6 +228,7 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.listenerManager"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.experimental"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.udtf"),
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.dataSource"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataFrame"),
ProblemFilters.exclude[Problem](
"org.apache.spark.sql.SparkSession.baseRelationToDataFrame"),
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 29dc194b7f8..5cda6596b3f 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,16 +15,18 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterator, Tuple, Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
+from pyspark import since
from pyspark.sql import Row
from pyspark.sql.types import StructType
if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
+ from pyspark.sql.session import SparkSession
-__all__ = ["DataSource", "DataSourceReader"]
+__all__ = ["DataSource", "DataSourceReader", "DataSourceRegistration"]
class DataSource(ABC):
@@ -41,23 +43,35 @@ class DataSource(ABC):
``spark.read.format(...).load()`` and save data using
``df.write.format(...).save()``.
"""
- def __init__(self, options: Dict[str, "OptionalPrimitiveType"]):
+ @final
+ def __init__(
+ self,
+ paths: List[str],
+ userSpecifiedSchema: Optional[StructType],
+ options: Dict[str, "OptionalPrimitiveType"],
+ ) -> None:
"""
- Initializes the data source with user-provided options.
+ Initializes the data source with user-provided information.
Parameters
----------
+ paths : list
+ A list of paths to the data source.
+ userSpecifiedSchema : StructType, optional
+ The user-specified schema of the data source.
options : dict
A dictionary representing the options for this data source.
Notes
-----
- This method should not contain any non-serializable objects.
+ This method should not be overridden.
"""
+ self.paths = paths
+ self.userSpecifiedSchema = userSpecifiedSchema
self.options = options
- @property
- def name(self) -> str:
+ @classmethod
+ def name(cls) -> str:
"""
Returns a string represents the format name of this data source.
@@ -66,20 +80,21 @@ class DataSource(ABC):
Examples
--------
- >>> def name(self):
+ >>> def name(cls):
... return "my_data_source"
"""
- return self.__class__.__name__
+ return cls.__name__
def schema(self) -> Union[StructType, str]:
"""
Returns the schema of the data source.
- It can reference the ``options`` field to infer the data source's schema when
- users do not explicitly specify it. This method is invoked once when calling
- ``spark.read.format(...).load()`` to get the schema for a data source read
- operation. If this method is not implemented, and a user does not provide a
- schema when reading the data source, an exception will be thrown.
+ It can refer any field initialized in the ``__init__`` method to infer the
+ data source's schema when users do not explicitly specify it. This method is
+ invoked once when calling ``spark.read.format(...).load()`` to get the schema
+ for a data source read operation. If this method is not implemented, and a
+ user does not provide a schema when reading the data source, an exception will
+ be thrown.
Returns
-------
@@ -212,3 +227,35 @@ class DataSourceReader(ABC):
... yield Row(partition=partition, value=1)
"""
...
+
+
+@since(4.0)
+class DataSourceRegistration:
+ """
+ Wrapper for data source registration. This instance can be accessed by
+ :attr:`spark.dataSource`.
+ """
+
+ def __init__(self, sparkSession: "SparkSession"):
+ self.sparkSession = sparkSession
+
+ def register(
+ self,
+ dataSource: Type["DataSource"],
+ ) -> None:
+ """Register a Python user-defined data source.
+
+ Parameters
+ ----------
+ dataSource : type
+ The data source class to be registered. It should be a subclass of DataSource.
+ """
+ from pyspark.sql.udf import _wrap_function
+
+ name = dataSource.name()
+ sc = self.sparkSession.sparkContext
+ # Serialize the data source class.
+ wrapped = _wrap_function(sc, dataSource)
+ assert sc._jvm is not None
+ ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
+ self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 5c07ca607b7..1ffb602ba86 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
from pyspark.sql.streaming import StreamingQueryManager
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.udtf import UDTFRegistration
+ from pyspark.sql.datasource import DataSourceRegistration
# Running MyPy type checks will always require pandas and
# other dependencies so importing here is fine.
@@ -863,6 +864,20 @@ class SparkSession(SparkConversionMixin):
return UDTFRegistration(self)
+ @property
+ def dataSource(self) -> "DataSourceRegistration":
+ """Returns a :class:`DataSourceRegistration` for data source
registration.
+
+ .. versionadded:: 4.0.0
+
+ Returns
+ -------
+ :class:`DataSourceRegistration`
+ """
+ from pyspark.sql.datasource import DataSourceRegistration
+
+ return DataSourceRegistration(self)
+
def range(
self,
start: int,
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 3ac103d2c7c..6584312dda4 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -26,9 +26,9 @@ class BasePythonDataSourceTestsMixin:
...
options = dict(a=1, b=2)
- ds = MyDataSource(options)
+ ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
self.assertEqual(ds.options, options)
- self.assertEqual(ds.name, "MyDataSource")
+ self.assertEqual(ds.name(), "MyDataSource")
with self.assertRaises(NotImplementedError):
ds.schema()
with self.assertRaises(NotImplementedError):
@@ -43,6 +43,18 @@ class BasePythonDataSourceTestsMixin:
self.assertEqual(list(reader.partitions()), [None])
self.assertEqual(list(reader.read(None)), [(None,)])
+ def test_register_data_source(self):
+ class MyDataSource(DataSource):
+ ...
+
+ self.spark.dataSource.register(MyDataSource)
+
+ self.assertTrue(
+ self.spark._jsparkSession.sharedState()
+ .dataSourceManager()
+ .dataSourceExists("MyDataSource")
+ )
+
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
new file mode 100644
index 00000000000..ea56d2cc752
--- /dev/null
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+from typing import IO, List
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+ read_bool,
+ read_int,
+ write_int,
+ write_with_length,
+ SpecialLengths,
+)
+from pyspark.sql.datasource import DataSource
+from pyspark.sql.types import _parse_datatype_json_string, StructType
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+ check_python_version,
+ read_command,
+ pickleSer,
+ send_accumulator_updates,
+ setup_broadcasts,
+ setup_memory_limits,
+ setup_spark_files,
+ utf8_deserializer,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+ """
+ Main method for creating a Python data source instance.
+
+ This process is invoked from the `UserDefinedPythonDataSourceRunner.runInPython` method
+ in JVM. This process is responsible for creating a `DataSource` object and send the
+ information needed back to the JVM.
+
+ The JVM sends the following information to this process:
+ - a `DataSource` class representing the data source to be created.
+ - a provider name in string.
+ - a list of paths in string.
+ - an optional user-specified schema in json string.
+ - a dictionary of options in string.
+
+ This process then creates a `DataSource` instance using the above information and
+ sends the pickled instance as well as the schema back to the JVM.
+ """
+ try:
+ check_python_version(infile)
+
+ memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+ setup_memory_limits(memory_limit_mb)
+
+ setup_spark_files(infile)
+ setup_broadcasts(infile)
+
+ _accumulatorRegistry.clear()
+
+ # Receive the data source class.
+ data_source_cls = read_command(pickleSer, infile)
+ if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "a subclass of DataSource",
+ "actual": f"'{type(data_source_cls).__name__}'",
+ },
+ )
+
+ # Receive the provider name.
+ provider = utf8_deserializer.loads(infile)
+ if provider.lower() != data_source_cls.name().lower():
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": f"provider with name {data_source_cls.name()}",
+ "actual": f"'{provider}'",
+ },
+ )
+
+ # Receive the paths.
+ num_paths = read_int(infile)
+ paths: List[str] = []
+ for _ in range(num_paths):
+ paths.append(utf8_deserializer.loads(infile))
+
+ # Receive the user-specified schema
+ user_specified_schema = None
+ if read_bool(infile):
+ user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(user_specified_schema, StructType):
+ raise PySparkAssertionError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "the user-defined schema to be a
'StructType'",
+ "actual": f"'{type(data_source_cls).__name__}'",
+ },
+ )
+
+ # Receive the options.
+ options = dict()
+ num_options = read_int(infile)
+ for _ in range(num_options):
+ key = utf8_deserializer.loads(infile)
+ value = utf8_deserializer.loads(infile)
+ options[key] = value
+
+ # Instantiate a data source.
+ try:
+ data_source = data_source_cls(
+ paths=paths,
+ userSpecifiedSchema=user_specified_schema, # type: ignore
+ options=options,
+ )
+ except Exception as e:
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
+ message_parameters={"type": "instance", "error": str(e)},
+ )
+
+ # Get the schema of the data source.
+ # If user_specified_schema is not None, use user_specified_schema.
+ # Otherwise, use the schema of the data source.
+ # Throw exception if the data source does not implement schema().
+ is_ddl_string = False
+ if user_specified_schema is None:
+ try:
+ schema = data_source.schema()
+ if isinstance(schema, str):
+ # Here we cannot use _parse_datatype_string to parse the DDL string schema.
+ # as it requires an active Spark session.
+ is_ddl_string = True
+ except NotImplementedError:
+ raise PySparkRuntimeError(
+ error_class="PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED",
+ message_parameters={"type": "instance", "method":
"schema"},
+ )
+ else:
+ schema = user_specified_schema # type: ignore
+
+ assert schema is not None
+
+ # Return the pickled data source instance.
+ pickleSer._write_with_length(data_source, outfile)
+
+ # Return the schema of the data source.
+ write_int(int(is_ddl_string), outfile)
+ if is_ddl_string:
+ write_with_length(schema.encode("utf-8"), outfile) # type: ignore
+ else:
+ write_with_length(schema.json().encode("utf-8"), outfile)  # type: ignore
+
+ except BaseException as e:
+ handle_worker_exception(e, outfile)
+ sys.exit(-1)
+
+ send_accumulator_updates(outfile)
+
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ sys.exit(-1)
+
+
+if __name__ == "__main__":
+ # Read information about how to connect back to the JVM from the environment.
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+ auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+ (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ main(sock_file, sock_file)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 5fcd71d8bf9..4b28eadfec6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1174,10 +1174,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
}
def schemaIsNotStructTypeError(exp: Expression, dataType: DataType): Throwable = {
+ schemaIsNotStructTypeError(toSQLExpr(exp), dataType)
+ }
+
+ def schemaIsNotStructTypeError(inputSchema: String, dataType: DataType): Throwable = {
new AnalysisException(
errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
messageParameters = Map(
- "inputSchema" -> toSQLExpr(exp),
+ "inputSchema" -> inputSchema,
"dataType" -> toSQLType(dataType)
))
}
@@ -1981,10 +1985,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
)
}
- def failToPlanDataSourceError(tpe: String, msg: String): Throwable = {
+ def failToPlanDataSourceError(action: String, tpe: String, msg: String): Throwable = {
new AnalysisException(
errorClass = "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON",
- messageParameters = Map("type" -> tpe, "msg" -> msg)
+ messageParameters = Map("action" -> action, "type" -> tpe, "msg" -> msg)
)
}
@@ -3801,4 +3805,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"functionName" -> functionName,
"reason" -> reason))
}
+
+ def dataSourceAlreadyExists(name: String): Throwable = {
+ new AnalysisException(
+ errorClass = "DATA_SOURCE_ALREADY_EXISTS",
+ messageParameters = Map("provider" -> name))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
new file mode 100644
index 00000000000..15d26418984
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.Evolving
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.execution.datasources.DataSourceManager
+import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+
+/**
+ * Functions for registering user-defined data sources.
+ * Use `SparkSession.dataSource` to access this.
+ */
+@Evolving
+private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager)
+ extends Logging {
+
+ protected[sql] def registerPython(
+ name: String,
+ dataSource: UserDefinedPythonDataSource): Unit = {
+ log.debug(
+ s"""
+ | Registering new Python data source:
+ | name: $name
+ | command: ${dataSource.dataSourceCls.command}
+ | envVars: ${dataSource.dataSourceCls.envVars}
+ | pythonIncludes: ${dataSource.dataSourceCls.pythonIncludes}
+ | pythonExec: ${dataSource.dataSourceCls.pythonExec}
+ """.stripMargin)
+
+ dataSourceManager.registerDataSource(name, dataSource.builder)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 59aa17baa7f..aec40c845fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -228,6 +228,11 @@ class SparkSession private(
def udtf: UDTFRegistration = sessionState.udtfRegistration
+ /**
+ * A collection of methods for registering user-defined data sources.
+ */
+ private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+
/**
* Returns a `StreamingQueryManager` that allows managing all the
* `StreamingQuery`s active on `this`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
new file mode 100644
index 00000000000..283ca2ac62e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.Locale
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class DataSourceManager {
+
+ private type DataSourceBuilder = (
+ SparkSession, // Spark session
+ String, // provider name
+ Seq[String], // paths
+ Option[StructType], // user specified schema
+ CaseInsensitiveStringMap // options
+ ) => LogicalPlan
+
+ private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
+
+ private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+
+ def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
+ val normalizedName = normalize(name)
+ if (dataSourceBuilders.containsKey(normalizedName)) {
+ throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+ }
+ // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
+ dataSourceBuilders.put(normalizedName, builder)
+ }
+
+ def dataSourceExists(name: String): Boolean =
+ dataSourceBuilders.containsKey(normalize(name))
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index faae4101fa0..dbff8eefcd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,25 +20,133 @@ package org.apache.spark.sql.execution.python
import java.io.{DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler
-import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.PythonDataSource
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A user-defined Python data source. This is used by the Python API.
+ *
+ * @param dataSourceCls The Python data source class.
+ */
+case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
+
+ def builder(
+ sparkSession: SparkSession,
+ provider: String,
+ paths: Seq[String],
+ userSpecifiedSchema: Option[StructType],
+ options: CaseInsensitiveStringMap): LogicalPlan = {
+
+ val runner = new UserDefinedPythonDataSourceRunner(
+ dataSourceCls, provider, paths, userSpecifiedSchema, options)
+
+ val result = runner.runInPython()
+ val pickledDataSourceInstance = result.dataSource
+
+ val dataSource = SimplePythonFunction(
+ command = pickledDataSourceInstance,
+ envVars = dataSourceCls.envVars,
+ pythonIncludes = dataSourceCls.pythonIncludes,
+ pythonExec = dataSourceCls.pythonExec,
+ pythonVer = dataSourceCls.pythonVer,
+ broadcastVars = dataSourceCls.broadcastVars,
+ accumulator = dataSourceCls.accumulator)
+ val schema = result.schema
+
+ PythonDataSource(dataSource, schema, output = toAttributes(schema))
+ }
+
+ def apply(
+ sparkSession: SparkSession,
+ provider: String,
+ paths: Seq[String] = Seq.empty,
+ userSpecifiedSchema: Option[StructType] = None,
+ options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = {
+ val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
+ Dataset.ofRows(sparkSession, plan)
+ }
+}
+
+/**
+ * Used to store the result of creating a Python data source in the Python process.
*/
-case class UserDefinedPythonDataSource(
- dataSource: PythonFunction,
- schema: StructType) {
- def apply(session: SparkSession): DataFrame = {
- val source = PythonDataSource(dataSource, schema, output = toAttributes(schema))
- Dataset.ofRows(session, source)
+case class PythonDataSourceCreationResult(
+ dataSource: Array[Byte],
+ schema: StructType)
+
+/**
+ * A runner used to create a Python data source in a Python process and return the result.
+ */
+class UserDefinedPythonDataSourceRunner(
+ dataSourceCls: PythonFunction,
+ provider: String,
+ paths: Seq[String],
+ userSpecifiedSchema: Option[StructType],
+ options: CaseInsensitiveStringMap)
+ extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
+
+ override val workerModule = "pyspark.sql.worker.create_data_source"
+
+ override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+ // Send python data source
+ PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut)
+
+ // Send the provider name
+ PythonWorkerUtils.writeUTF(provider, dataOut)
+
+ // Send the paths
+ dataOut.writeInt(paths.length)
+ paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
+
+ // Send the user-specified schema, if provided
+ dataOut.writeBoolean(userSpecifiedSchema.isDefined)
+ userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))
+
+ // Send the options
+ dataOut.writeInt(options.size)
+ options.entrySet.asScala.foreach { e =>
+ PythonWorkerUtils.writeUTF(e.getKey, dataOut)
+ PythonWorkerUtils.writeUTF(e.getValue, dataOut)
+ }
+ }
+
+ override protected def receiveFromPython(
+ dataIn: DataInputStream): PythonDataSourceCreationResult = {
+ // Receive the pickled data source or an exception raised in Python worker.
+ val length = dataIn.readInt()
+ if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryCompilationErrors.failToPlanDataSourceError(
+ action = "create", tpe = "instance", msg = msg)
+ }
+
+ // Receive the pickled data source.
+ val pickledDataSourceInstance: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn)
+
+ // Receive the schema.
+ val isDDLString = dataIn.readInt()
+ val schemaStr = PythonWorkerUtils.readUTF(dataIn)
+ val schema = if (isDDLString == 1) {
+ DataType.fromDDL(schemaStr)
+ } else {
+ DataType.fromJson(schemaStr)
+ }
+ if (!schema.isInstanceOf[StructType]) {
+ throw QueryCompilationErrors.schemaIsNotStructTypeError(schemaStr, schema)
+ }
+
+ PythonDataSourceCreationResult(
+ dataSource = pickledDataSourceInstance,
+ schema = schema.asInstanceOf[StructType])
}
}
@@ -66,7 +174,8 @@ class UserDefinedPythonDataSourceReadRunner(
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
- throw QueryCompilationErrors.failToPlanDataSourceError("read", msg)
+ throw QueryCompilationErrors.failToPlanDataSourceError(
+ action = "plan", tpe = "read", msg = msg)
}
// Receive the pickled 'read' function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 164710cdd88..8adc32fcf62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,9 +30,11 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.DataSourceRegistration
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.CacheManager
+import org.apache.spark.sql.execution.datasources.DataSourceManager
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -105,6 +107,16 @@ private[sql] class SharedState(
@GuardedBy("activeQueriesLock")
private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
+ /**
+ * A data source manager shared by all sessions.
+ */
+ lazy val dataSourceManager = new DataSourceManager()
+
+ /**
+ * A collection of methods used for registering user-defined data sources.
+ */
+ lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
+
/**
* A status store to query SQL status/metrics of this Spark application, based on SQL-specific
* [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 309d25ede4f..eab3b7d81b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -232,8 +232,8 @@ object IntegratedUDFTestUtils extends SQLHelper {
"from pyspark.serializers import CloudPickleSerializer; " +
s"f = open('$path', 'wb');" +
s"exec(open('$codePath', 'r').read());" +
- s"ds = $dataSourceName(options=dict());" +
- "f.write(CloudPickleSerializer().dumps(ds))"),
+ s"dataSourceCls = $dataSourceName;" +
+ "f.write(CloudPickleSerializer().dumps(dataSourceCls))"),
None,
"PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
binaryPythonDataSource = Files.readAllBytes(path.toPath)
@@ -425,19 +425,16 @@ object IntegratedUDFTestUtils extends SQLHelper {
def createUserDefinedPythonDataSource(
name: String,
- pythonScript: String,
- schema: StructType): UserDefinedPythonDataSource = {
+ pythonScript: String): UserDefinedPythonDataSource = {
UserDefinedPythonDataSource(
- dataSource = SimplePythonFunction(
+ dataSourceCls = SimplePythonFunction(
command = createPythonDataSource(name, pythonScript),
envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
pythonIncludes = List.empty[String].asJava,
pythonExec = pythonExec,
pythonVer = pythonVer,
broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
- accumulator = null),
- schema = schema
- )
+ accumulator = null))
}
def createUserDefinedPythonTableFunction(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index b038bcd5529..6c749c2c9b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -25,26 +25,34 @@ import org.apache.spark.sql.types.StructType
class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
import IntegratedUDFTestUtils._
+ private def dataSourceName = "SimpleDataSource"
+ private def simpleDataSourceReaderScript: String =
+ """
+ |class SimpleDataSourceReader(DataSourceReader):
+ | def partitions(self):
+ | return range(0, 2)
+ | def read(self, partition):
+ | yield (0, partition)
+ | yield (1, partition)
+ | yield (2, partition)
+ |""".stripMargin
+
test("simple data source") {
+ assume(shouldTestPythonUDFs)
val dataSourceScript =
- """
+ s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
- |class MyDataSourceReader(DataSourceReader):
- | def partitions(self):
- | return range(0, 2)
- | def read(self, partition):
- | yield (0, partition)
- | yield (1, partition)
- | yield (2, partition)
+ |$simpleDataSourceReaderScript
|
- |class MyDataSource(DataSource):
+ |class $dataSourceName(DataSource):
| def reader(self, schema):
- | return MyDataSourceReader()
+ | return SimpleDataSourceReader()
|""".stripMargin
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
- name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
- val df = dataSource(spark)
+ name = dataSourceName, pythonScript = dataSourceScript)
+ val df = dataSource.apply(
+ spark, provider = dataSourceName, userSpecifiedSchema = Some(schema))
assert(df.rdd.getNumPartitions == 2)
val plan = df.queryExecution.optimizedPlan
plan match {
@@ -55,36 +63,130 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
}
+ test("simple data source with string schema") {
+ assume(shouldTestPythonUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |$simpleDataSourceReaderScript
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT, partition INT"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ val df = dataSource(spark, provider = dataSourceName)
+ checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+ }
+
+ test("simple data source with StructType schema") {
+ assume(shouldTestPythonUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |from pyspark.sql.types import IntegerType, StructType, StructField
+ |$simpleDataSourceReaderScript
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return StructType([
+ | StructField("id", IntegerType()),
+ | StructField("partition", IntegerType())
+ | ])
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ val df = dataSource(spark, provider = dataSourceName)
+ checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+ }
+
+ test("data source with invalid schema") {
+ assume(shouldTestPythonUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |$simpleDataSourceReaderScript
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "INT"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ checkError(
+ exception = intercept[AnalysisException](dataSource(spark, provider = dataSourceName)),
+ errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
+ parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
+ }
+
+ test("register data source") {
+ assume(shouldTestPythonUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |$simpleDataSourceReaderScript
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT, partition INT"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader()
+ |""".stripMargin
+
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+
+ // Check error when registering a data source with the same name.
+ val err = intercept[AnalysisException] {
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ }
+ checkError(
+ exception = err,
+ errorClass = "DATA_SOURCE_ALREADY_EXISTS",
+ parameters = Map("provider" -> dataSourceName))
+ }
+
test("reader not implemented") {
+ assume(shouldTestPythonUDFs)
val dataSourceScript =
- """
+ s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
- |class MyDataSource(DataSource):
+ |class $dataSourceName(DataSource):
| pass
|""".stripMargin
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
- name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+ name = dataSourceName, pythonScript = dataSourceScript)
val err = intercept[AnalysisException] {
- dataSource(spark).collect()
+ dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED"))
}
test("error creating reader") {
+ assume(shouldTestPythonUDFs)
val dataSourceScript =
- """
+ s"""
|from pyspark.sql.datasource import DataSource
- |class MyDataSource(DataSource):
+ |class $dataSourceName(DataSource):
| def reader(self, schema):
| raise Exception("error creating reader")
|""".stripMargin
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
- name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+ name = dataSourceName, pythonScript = dataSourceScript)
val err = intercept[AnalysisException] {
- dataSource(spark).collect()
+ dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
@@ -92,17 +194,18 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
}
test("data source assertion error") {
+ assume(shouldTestPythonUDFs)
val dataSourceScript =
- """
- |class MyDataSource:
+ s"""
+ |class $dataSourceName:
| def __init__(self, options):
| ...
|""".stripMargin
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
- name = "MyDataSource", pythonScript = dataSourceScript, schema = schema)
+ name = dataSourceName, pythonScript = dataSourceScript)
val err = intercept[AnalysisException] {
- dataSource(spark).collect()
+ dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_TYPE_MISMATCH"))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]