This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 9d93b7112a31 [SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader
9d93b7112a31 is described below

commit 9d93b7112a31965447a34301889f90d14578e628
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Wed Nov 8 09:23:12 2023 -0800

    [SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader

    ### What changes were proposed in this pull request?

    This PR supports `spark.read.format(...).load()` for Python data sources. After this PR, users can use a Python data source directly like this:

    ```python
    from pyspark.sql.datasource import DataSource, DataSourceReader

    class MyReader(DataSourceReader):
        def read(self, partition):
            yield (0, 1)

    class MyDataSource(DataSource):
        @classmethod
        def name(cls):
            return "my-source"

        def schema(self):
            return "id INT, value INT"

        def reader(self, schema):
            return MyReader()

    spark.dataSource.register(MyDataSource)
    df = spark.read.format("my-source").load()
    df.show()
    +---+-----+
    | id|value|
    +---+-----+
    |  0|    1|
    +---+-----+
    ```

    ### Why are the changes needed?

    To support Python data sources.

    ### Does this PR introduce _any_ user-facing change?

    Yes. After this PR, users can load a custom Python data source using `spark.read.format(...).load()`.

    ### How was this patch tested?

    New unit tests.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #43630 from allisonwang-db/spark-45639-ds-lookup.

    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/resources/error/error-classes.json    | 12 +++
 dev/sparktestsupport/modules.py                    |  1 +
 docs/sql-error-conditions.md                       | 12 +++
 python/pyspark/sql/session.py                      |  4 +
 python/pyspark/sql/tests/test_python_datasource.py | 97 ++++++++++++++++++++--
 python/pyspark/sql/worker/create_data_source.py    | 16 +++-
 .../spark/sql/errors/QueryCompilationErrors.scala  | 12 +++
 .../org/apache/spark/sql/DataFrameReader.scala     | 48 +++++++++--
 .../execution/datasources/DataSourceManager.scala  | 31 ++++++-
 .../python/UserDefinedPythonDataSource.scala       | 15 ++--
 .../execution/python/PythonDataSourceSuite.scala   | 35 ++++++++
 11 files changed, 255 insertions(+), 28 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index db46ee8ca208..c38171c3d9e6 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -850,6 +850,12 @@
     ],
     "sqlState" : "42710"
   },
+  "DATA_SOURCE_NOT_EXIST" : {
+    "message" : [
+      "Data source '<provider>' not found. Please make sure the data source is registered."
+    ],
+    "sqlState" : "42704"
+  },
   "DATA_SOURCE_NOT_FOUND" : {
     "message" : [
       "Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
@@ -1095,6 +1101,12 @@
     ],
     "sqlState" : "42809"
   },
+  "FOUND_MULTIPLE_DATA_SOURCES" : {
+    "message" : [
+      "Detected multiple data sources with the name '<provider>'. Please check the data source isn't simultaneously registered and located in the classpath."
+    ],
+    "sqlState" : "42710"
+  },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [
       "A column cannot have both a default value and a generation expression but column <colName> has default value: (<defaultValue>) and generation expression: (<genExpr>)."
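As a sketch of when the new `FOUND_MULTIPLE_DATA_SOURCES` condition fires — assuming an active SparkSession `spark` and the built-in JSON source on the classpath; the class and path below are illustrative, not part of this change:

```python
from pyspark.sql.datasource import DataSource

class ShadowJsonDataSource(DataSource):
    @classmethod
    def name(cls):
        return "json"  # deliberately collides with the built-in JSON source

spark.dataSource.register(ShadowJsonDataSource)
# The name now resolves both on the classpath and in the Python registry,
# so the read is ambiguous and fails with FOUND_MULTIPLE_DATA_SOURCES.
spark.read.format("json").load("/tmp/people.json")
```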
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 95c9069a8313..01757ba28dd2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -511,6 +511,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
         "pyspark.sql.tests.pandas.test_converter",
         "pyspark.sql.tests.test_pandas_sqlmetrics",
+        "pyspark.sql.tests.test_python_datasource",
         "pyspark.sql.tests.test_readwriter",
         "pyspark.sql.tests.test_serde",
         "pyspark.sql.tests.test_session",
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 7b0bc8ceb2b5..8a5faa15dc9c 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -454,6 +454,12 @@ DataType `<type>` requires a length parameter, for example `<type>`(10). Please
 
 Data source '`<provider>`' already exists in the registry. Please use a different name for the new data source.
 
+### DATA_SOURCE_NOT_EXIST
+
+[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Data source '`<provider>`' not found. Please make sure the data source is registered.
+
 ### DATA_SOURCE_NOT_FOUND
 
 [SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -669,6 +675,12 @@ No such struct field `<fieldName>` in `<fields>`.
 
 The operation `<statement>` is not allowed on the `<objectType>`: `<objectName>`.
 
+### FOUND_MULTIPLE_DATA_SOURCES
+
+[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Detected multiple data sources with the name '`<provider>`'. Please check the data source isn't simultaneously registered and located in the classpath.
+
 ### GENERATED_COLUMN_WITH_DEFAULT_VALUE
 
 [SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4ab7281d7ac8..85aff09aa3df 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -884,6 +884,10 @@ class SparkSession(SparkConversionMixin):
         Returns
         -------
         :class:`DataSourceRegistration`
+
+        Notes
+        -----
+        This feature is experimental and unstable.
         """
         from pyspark.sql.datasource import DataSourceRegistration
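The tests that follow exercise the reader contract end to end. As a minimal standalone sketch of that contract — class names are illustrative and an active SparkSession `spark` is assumed — each value returned by `partitions()` is handed back to `read()`, which yields tuples matching the declared schema:

```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class RangeReader(DataSourceReader):
    def __init__(self, options):
        self.n = int(options.get("n", 3))

    def partitions(self):
        return range(self.n)  # one output partition per element

    def read(self, partition):
        yield partition, str(partition)  # one tuple per output row

class RangeDataSource(DataSource):
    @classmethod
    def name(cls):
        return "range-sketch"

    def schema(self):
        return "x INT, y STRING"

    def reader(self, schema):
        return RangeReader(self.options)

spark.dataSource.register(RangeDataSource)
spark.read.format("range-sketch").option("n", 2).load().show()
```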
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index b429d73fb7d7..fe6a84175274 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import unittest
 
 from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import SPARK_HOME
 
 
 class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
-    def test_register_data_source(self):
-        class MyDataSource(DataSource):
-            ...
+    def test_in_memory_data_source(self):
+        class InMemDataSourceReader(DataSourceReader):
+            DEFAULT_NUM_PARTITIONS: int = 3
+
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                if "num_partitions" in self.options:
+                    num_partitions = int(self.options["num_partitions"])
+                else:
+                    num_partitions = self.DEFAULT_NUM_PARTITIONS
+                return range(num_partitions)
+
+            def read(self, partition):
+                yield partition, str(partition)
+
+        class InMemoryDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "memory"
+
+            def schema(self):
+                return "x INT, y STRING"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return InMemDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(InMemoryDataSource)
+        df = self.spark.read.format("memory").load()
+        self.assertEqual(df.rdd.getNumPartitions(), 3)
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])
 
-        self.spark.dataSource.register(MyDataSource)
+        df = self.spark.read.format("memory").option("num_partitions", 2).load()
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
+        self.assertEqual(df.rdd.getNumPartitions(), 2)
+
+    def test_custom_json_data_source(self):
+        import json
+
+        class JsonDataSourceReader(DataSourceReader):
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                return iter(self.paths)
+
+            def read(self, path):
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class JsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "my-json"
+
+            def schema(self):
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return JsonDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(JsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+        path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
+        df1 = self.spark.read.format("my-json").load(path1)
+        self.assertEqual(df1.rdd.getNumPartitions(), 1)
+        assertDataFrameEqual(
+            df1,
+            [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
+        )
 
-        self.assertTrue(
-            self.spark._jsparkSession.sharedState()
-            .dataSourceRegistry()
-            .dataSourceExists("MyDataSource")
+        df2 = self.spark.read.format("my-json").load([path1, path2])
+        self.assertEqual(df2.rdd.getNumPartitions(), 2)
+        assertDataFrameEqual(
+            df2,
+            [
+                Row(name="Michael", age=None),
+                Row(name="Andy", age=30),
+                Row(name="Justin", age=19),
+                Row(name="Jonathan", age=None),
+            ],
         )
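The worker change below rejects data sources whose `name()` is not a classmethod. A quick sketch of the two shapes it distinguishes (class names are illustrative):

```python
import inspect

from pyspark.sql.datasource import DataSource

class BadNameDataSource(DataSource):
    def name(self):  # plain instance method
        return "bad"

class GoodNameDataSource(DataSource):
    @classmethod
    def name(cls):  # bound to the class, as the worker expects
        return "good"

# Accessed on the class itself, only the classmethod is a bound method.
print(inspect.ismethod(BadNameDataSource.name))   # False -> PYTHON_DATA_SOURCE_TYPE_MISMATCH
print(inspect.ismethod(GoodNameDataSource.name))  # True
```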
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index ea56d2cc7522..6a9ef79b7c18 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import inspect
 import os
 import sys
 from typing import IO, List
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
     read_bool,
@@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None:
             },
         )
 
+    # Check the name method is a class method.
+    if not inspect.ismethod(data_source_cls.name):
+        raise PySparkTypeError(
+            error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+            message_parameters={
+                "expected": "'name()' method to be a classmethod",
+                "actual": f"'{type(data_source_cls.name).__name__}'",
+            },
+        )
+
     # Receive the provider name.
     provider = utf8_deserializer.loads(infile)
+
+    # Check if the provider name matches the data source's name.
     if provider.lower() != data_source_cls.name().lower():
         raise PySparkAssertionError(
             error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 1925eddd2ce2..0c5dcb1ead01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
       errorClass = "DATA_SOURCE_ALREADY_EXISTS",
       messageParameters = Map("provider" -> name))
   }
+
+  def dataSourceDoesNotExist(name: String): Throwable = {
+    new AnalysisException(
+      errorClass = "DATA_SOURCE_NOT_EXIST",
+      messageParameters = Map("provider" -> name))
+  }
+
+  def foundMultipleDataSources(provider: String): Throwable = {
+    new AnalysisException(
+      errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
+      messageParameters = Map("provider" -> provider))
+  }
 }
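For a name that matches neither the classpath nor the Python registry, the reader change below rethrows the original not-found error rather than the new `DATA_SOURCE_NOT_EXIST`; a sketch, assuming an active SparkSession `spark` and an illustrative format name:

```python
try:
    spark.read.format("no-such-source").load()
except Exception as e:
    # Nothing is registered under this name and nothing matches on the
    # classpath, so the original DATA_SOURCE_NOT_FOUND error surfaces.
    print(e)
```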
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9992d8cbba07..ef447e8a8010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql
 
-import java.util.{Locale, Properties}
+import java.util.{Locale, Properties, ServiceConfigurationError}
 
 import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.Partition
+import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
     }
 
-    DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
-      DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
-        source, paths: _*)
-    }.getOrElse(loadV1Source(paths: _*))
+    val isUserDefinedDataSource =
+      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+
+    Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
+      case Success(providerOpt) =>
+        // The source can be successfully loaded as either a V1 or a V2 data source.
+        // Check if it is also a user-defined data source.
+        if (isUserDefinedDataSource) {
+          throw QueryCompilationErrors.foundMultipleDataSources(source)
+        }
+        providerOpt.flatMap { provider =>
+          DataSourceV2Utils.loadV2Source(
+            sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
+        }.getOrElse(loadV1Source(paths: _*))
+      case Failure(exception) =>
+        // Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
+        // For the following not found exceptions, if the user-defined data source is defined,
+        // we can instead return the user-defined data source.
+        val isNotFoundError = exception match {
+          case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
+          case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
+          case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
+          case _ => false
+        }
+        if (isNotFoundError && isUserDefinedDataSource) {
+          loadUserDefinedDataSource(paths)
+        } else {
+          // Throw the original exception.
+          throw exception
+        }
+    }
+  }
+
+  private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
+    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    // Unless the legacy path option behavior is enabled, the extraOptions here
+    // should not include "path" or "paths" as keys.
+    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+    Dataset.ofRows(sparkSession, plan)
   }
 
   private def loadV1Source(paths: String*) = {
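Putting the new `load()` lookup order together — assuming an active SparkSession `spark`, an illustrative path, and the `MyDataSource` class from the description registered as "my-source":

```python
# Classpath sources still win and are resolved exactly as before.
spark.read.format("parquet").load("/tmp/data")

# A name known only to the Python registry falls through to
# loadUserDefinedDataSource once the classpath lookup fails with
# a not-found error.
spark.read.format("my-source").load()
```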
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 283ca2ac62ed..72a9e6497aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -22,10 +22,14 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+/**
+ * A manager for user-defined data sources. It is used to register and lookup data sources by
+ * their short names or fully qualified names.
+ */
 class DataSourceManager {
 
   private type DataSourceBuilder = (
@@ -33,22 +37,41 @@ class DataSourceManager {
     String, // provider name
     Seq[String], // paths
     Option[StructType], // user specified schema
-    CaseInsensitiveStringMap // options
+    CaseInsensitiveMap[String] // options
   ) => LogicalPlan
 
   private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
 
   private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
 
+  /**
+   * Register a data source builder for the given provider.
+   * Note that the provider name is case-insensitive.
+   */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
     if (dataSourceBuilders.containsKey(normalizedName)) {
       throw QueryCompilationErrors.dataSourceAlreadyExists(name)
     }
-    // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
     dataSourceBuilders.put(normalizedName, builder)
   }
 
-  def dataSourceExists(name: String): Boolean =
+  /**
+   * Returns a data source builder for the given provider and throw an exception if
+   * it does not exist.
+   */
+  def lookupDataSource(name: String): DataSourceBuilder = {
+    if (dataSourceExists(name)) {
+      dataSourceBuilders.get(normalize(name))
+    } else {
+      throw QueryCompilationErrors.dataSourceDoesNotExist(name)
+    }
+  }
+
+  /**
+   * Checks if a data source with the specified name exists (case-insensitive).
+   */
+  def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
+  }
 }
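Since the manager normalizes provider names to lower case, registration and lookup are case-insensitive and duplicate registration fails; a sketch reusing the `InMemoryDataSource` class from the tests above, with an active session `spark` assumed:

```python
spark.dataSource.register(InMemoryDataSource)  # stored under the normalized name "memory"
spark.read.format("MEMORY").load().show()      # lookup normalizes too, so this resolves
spark.dataSource.register(InMemoryDataSource)  # raises DATA_SOURCE_ALREADY_EXISTS
```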
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index dbff8eefcd5f..703c1e10ce26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python
 import java.io.{DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
@@ -28,9 +27,9 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePyt
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource}
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * A user-defined Python data source. This is used by the Python API.
@@ -44,7 +43,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       provider: String,
       paths: Seq[String],
       userSpecifiedSchema: Option[StructType],
-      options: CaseInsensitiveStringMap): LogicalPlan = {
+      options: CaseInsensitiveMap[String]): LogicalPlan = {
     val runner = new UserDefinedPythonDataSourceRunner(
       dataSourceCls, provider, paths, userSpecifiedSchema, options)
 
@@ -70,7 +69,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       provider: String,
       paths: Seq[String] = Seq.empty,
       userSpecifiedSchema: Option[StructType] = None,
-      options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = {
+      options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
     val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
     Dataset.ofRows(sparkSession, plan)
   }
@@ -91,7 +90,7 @@ class UserDefinedPythonDataSourceRunner(
     provider: String,
     paths: Seq[String],
     userSpecifiedSchema: Option[StructType],
-    options: CaseInsensitiveStringMap)
+    options: CaseInsensitiveMap[String])
   extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
 
   override val workerModule = "pyspark.sql.worker.create_data_source"
@@ -113,9 +112,9 @@ class UserDefinedPythonDataSourceRunner(
 
     // Send the options
     dataOut.writeInt(options.size)
-    options.entrySet.asScala.foreach { e =>
-      PythonWorkerUtils.writeUTF(e.getKey, dataOut)
-      PythonWorkerUtils.writeUTF(e.getValue, dataOut)
+    options.iterator.foreach { case (key, value) =>
+      PythonWorkerUtils.writeUTF(key, dataOut)
+      PythonWorkerUtils.writeUTF(value, dataOut)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 6c749c2c9b67..22a1e5250cd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -155,6 +155,41 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       parameters = Map("provider" -> dataSourceName))
   }
 
+  test("load data source") {
+    assume(shouldTestPythonUDFs)
+    val dataSourceScript =
+      s"""
+        |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def __init__(self, paths, options):
+        |        self.paths = paths
+        |        self.options = options
+        |
+        |    def partitions(self):
+        |        return iter(self.paths)
+        |
+        |    def read(self, path):
+        |        yield (path, 1)
+        |
+        |class $dataSourceName(DataSource):
+        |    @classmethod
+        |    def name(cls) -> str:
+        |        return "test"
+        |
+        |    def schema(self) -> str:
+        |        return "id STRING, value INT"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader(self.paths, self.options)
+        |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython("test", dataSource)
+
+    checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
+    checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
+    checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
+  }
+
   test("reader not implemented") {
     assume(shouldTestPythonUDFs)
     val dataSourceScript =

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org