This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c8ad616b988 [SPARK-45600][PYTHON] Make Python data source registration session level
c8ad616b988 is described below
commit c8ad616b988efdd47d7091f51c1e4563564b4e10
Author: allisonwang-db <[email protected]>
AuthorDate: Thu Nov 23 11:04:09 2023 +0900
[SPARK-45600][PYTHON] Make Python data source registration session level
### What changes were proposed in this pull request?
This PR makes dynamic Python data source registration session-scoped. Previously, registered data sources were stored in the `sharedState` and could be referenced by other sessions, which won't work with Spark Connect.
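
For illustration, a minimal PySpark sketch of the session-scoped behavior (hypothetical class names; assumes a running `spark` session with this patch applied):

```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class ExampleReader(DataSourceReader):
    def read(self, partition):
        # Emit a single row matching the schema below.
        yield (42,)

class ExampleSource(DataSource):
    # name() defaults to the class name, so the format is "ExampleSource".
    def schema(self):
        return "value INT"

    def reader(self, schema):
        return ExampleReader()

# The registration now lives in this session's SessionState rather than in
# the SharedState visible to every session.
spark.dataSource.register(ExampleSource)
spark.read.format("ExampleSource").load().show()
```
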
### Why are the changes needed?
To make Python data sources support Spark Connect in the future.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43742 from allisonwang-db/spark-45600-session-level.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_python_datasource.py | 30 +++++++++++++++++
.../org/apache/spark/sql/DataFrameReader.scala | 4 +--
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +-
.../execution/datasources/DataSourceManager.scala | 15 ++++++---
.../sql/internal/BaseSessionStateBuilder.scala | 15 +++++++++
.../apache/spark/sql/internal/SessionState.scala | 5 +++
.../apache/spark/sql/internal/SharedState.scala | 12 -------
.../execution/python/PythonDataSourceSuite.scala | 38 +++++++++++++++++-----
8 files changed, 93 insertions(+), 28 deletions(-)
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 46b9fa642fd..bab062c4821 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -49,6 +49,36 @@ class BasePythonDataSourceTestsMixin:
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_data_source_register(self):
+        class TestReader(DataSourceReader):
+            def read(self, partition):
+                yield (0, 1)
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a INT, b INT"
+
+            def reader(self, schema):
+                return TestReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(a=0, b=1)])
+
+        class MyDataSource(TestDataSource):
+            @classmethod
+            def name(cls):
+                return "TestDataSource"
+
+            def schema(self):
+                return "c INT, d INT"
+
+        # Should be able to register the data source with the same name.
+        self.spark.dataSource.register(MyDataSource)
+
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(c=0, d=1)])
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7fadbbfac68..c29ffb32907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -210,7 +210,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     val isUserDefinedDataSource =
-      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
     Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
       case Success(providerOpt) =>
@@ -243,7 +243,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   }
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
     // Add `path` and `paths` options to the extra options if specified.
     val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 5eba9e59c17..24497add04f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -233,7 +233,7 @@ class SparkSession private(
   /**
    * A collection of methods for registering user-defined data sources.
    */
-  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+  private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration
 
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index a8c9c892b8b..1cdc3d9cb69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.StructType
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager {
+class DataSourceManager extends Logging {
 
   private type DataSourceBuilder = (
       SparkSession, // Spark session
@@ -49,10 +50,10 @@ class DataSourceManager {
    */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
-    if (dataSourceBuilders.containsKey(normalizedName)) {
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    if (previousValue != null) {
+      logWarning(f"The data source $name replaced a previously registered data source.")
     }
-    dataSourceBuilders.put(normalizedName, builder)
   }
 
   /**
/**
@@ -73,4 +74,10 @@ class DataSourceManager {
def dataSourceExists(name: String): Boolean = {
dataSourceBuilders.containsKey(normalize(name))
}
+
+ override def clone(): DataSourceManager = {
+ val manager = new DataSourceManager
+ dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
+ manager
+ }
}
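
The `registerDataSource` change above means that re-registering an existing name now replaces the old builder and logs a warning instead of throwing `DATA_SOURCE_ALREADY_EXISTS`. A minimal PySpark sketch of that override behavior (hypothetical names; assumes a running `spark` session):

```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class ReaderV1(DataSourceReader):
    def read(self, partition):
        yield (1,)

class ReaderV2(DataSourceReader):
    def read(self, partition):
        yield (2,)

class SourceV1(DataSource):
    @classmethod
    def name(cls):
        # Both versions register under the same short name.
        return "my_source"

    def schema(self):
        return "v INT"

    def reader(self, schema):
        return ReaderV1()

class SourceV2(SourceV1):
    def reader(self, schema):
        return ReaderV2()

spark.dataSource.register(SourceV1)
# No longer an error: the manager swaps in the new builder and logs a warning.
spark.dataSource.register(SourceV2)
spark.read.format("my_source").load().show()  # rows now come from ReaderV2
```
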
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 630e1202f6d..d198e8f5d1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -120,6 +120,13 @@ abstract class BaseSessionStateBuilder(
       .getOrElse(extensions.registerTableFunctions(TableFunctionRegistry.builtin.clone()))
   }
 
+  /**
+   * Manages the registration of data sources.
+   */
+  protected lazy val dataSourceManager: DataSourceManager = {
+    parentState.map(_.dataSourceManager.clone()).getOrElse(new DataSourceManager)
+  }
+
   /**
    * Experimental methods that can be used to define custom optimization rules and custom planning
    * strategies.
@@ -178,6 +185,12 @@ abstract class BaseSessionStateBuilder(
 
   protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
 
+  /**
+   * A collection of methods used for registering user-defined data sources.
+   */
+  protected def dataSourceRegistration: DataSourceRegistration =
+    new DataSourceRegistration(dataSourceManager)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -376,6 +389,8 @@ abstract class BaseSessionStateBuilder(
       tableFunctionRegistry,
       udfRegistration,
       udtfRegistration,
+      dataSourceManager,
+      dataSourceRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
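
A note on `parentState.map(_.dataSourceManager.clone())` above: a session state built from a parent (internal session cloning) starts with a copy of the parent's registrations, while `SparkSession.newSession()` builds its state without a parent and so starts empty. A rough sketch of the observable isolation, using only public APIs (hypothetical names; assumes a running `spark` session):

```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class TinyReader(DataSourceReader):
    def read(self, partition):
        yield (0,)

class TinySource(DataSource):
    def schema(self):
        return "x INT"

    def reader(self, schema):
        return TinyReader()

spark.dataSource.register(TinySource)
spark.read.format("TinySource").load().count()  # visible in this session

other = spark.newSession()  # fresh session state; nothing is cloned
try:
    other.read.format("TinySource").load()  # its DataSourceManager is empty
except Exception as err:
    print("not registered in the new session:", err)
```
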
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index adf3e0cb6ca..bc6710e6cbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -49,6 +50,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
  * @param udfRegistration Interface exposed to the user for registering user-defined functions.
  * @param udtfRegistration Interface exposed to the user for registering user-defined
  *                         table functions.
+ * @param dataSourceManager Internal catalog for managing data sources registered by users.
+ * @param dataSourceRegistration Interface exposed to users for registering data sources.
  * @param catalogBuilder a function to create an internal catalog for managing table and database
  *                       states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -73,6 +76,8 @@ private[sql] class SessionState(
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
     val udtfRegistration: UDTFRegistration,
+    val dataSourceManager: DataSourceManager,
+    val dataSourceRegistration: DataSourceRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index 8adc32fcf62..164710cdd88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -30,11 +30,9 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -107,16 +105,6 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
-  /**
-   * A data source manager shared by all sessions.
-   */
-  lazy val dataSourceManager = new DataSourceManager()
-
-  /**
-   * A collection of method used for registering user-defined data sources.
-   */
-  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
-
   /**
    * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index bd0b08cbec8..33b34b39ab2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
@@ -143,16 +144,35 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
-    // Check error when registering a data source with the same name.
-    val err = intercept[AnalysisException] {
-      spark.dataSource.registerPython(dataSourceName, dataSource)
-    }
-    checkError(
-      exception = err,
-      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
-      parameters = Map("provider" -> dataSourceName))
+    // Should be able to override an already registered data source.
+    val newScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def read(self, partition):
+         |        yield (0, )
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
+    spark.dataSource.registerPython(dataSourceName, newDataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0)))
   }
 
   test("load data source") {