This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a1b0da200b27 [SPARK-45597][PYTHON][SQL] Support creating table using a Python data source in SQL (DSv2 exec)
a1b0da200b27 is described below

commit a1b0da200b271214e9d6b3170308509d7d514c7f
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Dec 15 11:04:32 2023 -0800

    [SPARK-45597][PYTHON][SQL] Support creating table using a Python data source in SQL (DSv2 exec)
    
    ### What changes were proposed in this pull request?
    
    This PR is the same as https://github.com/apache/spark/pull/44233, but instead of going through `V1Table` it uses the original DSv2 interface by reusing the UDTF execution code.
    
    ### Why are the changes needed?
    
    So that Python Data Sources can be used everywhere else as well, including from SparkR and Scala.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can register their Python Data Sources and use them in SQL, SparkR, etc.
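    For example (a minimal sketch; the Python-side names such as `pyspark.sql.datasource.DataSource` and `spark.dataSource.register` come from the in-development Python Data Source API and are illustrative here, and the SQL usage follows this PR's title rather than the diff below):
    
    ```python
    from pyspark.sql.datasource import DataSource, DataSourceReader
    
    class SimpleDataSourceReader(DataSourceReader):
        def read(self, partition):
            # Yield rows matching the schema declared below.
            yield (0, 0)
            yield (1, 1)
    
    class SimpleDataSource(DataSource):
        @classmethod
        def name(cls):
            return "simple"
    
        def schema(self):
            return "id INT, partition INT"
    
        def reader(self, schema):
            return SimpleDataSourceReader()
    
    # Register the data source with the current session, then use it from SQL
    # (or via spark.read.format("simple").load(), as exercised in the updated tests).
    spark.dataSource.register(SimpleDataSource)
    spark.sql("CREATE TABLE t USING simple")
    spark.sql("SELECT * FROM t").show()
    ```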
    
    ### How was this patch tested?
    
    Unit tests were added, and the change was manually tested.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44269
    Closes #44233
    Closes #43784
    
    Closes #44305 from HyukjinKwon/SPARK-45597-3.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../plans/logical/pythonLogicalOperators.scala     |  40 +----
 .../org/apache/spark/sql/DataFrameReader.scala     |  48 +----
 .../apache/spark/sql/DataSourceRegistration.scala  |   2 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../spark/sql/execution/SparkOptimizer.scala       |   8 +-
 .../spark/sql/execution/SparkStrategies.scala      |   2 -
 .../apache/spark/sql/execution/command/ddl.scala   |   6 +-
 .../spark/sql/execution/command/tables.scala       |   7 +-
 .../sql/execution/datasources/DataSource.scala     |  35 +++-
 .../execution/datasources/DataSourceManager.scala  |  24 +--
 .../datasources/PlanPythonDataSourceScan.scala     |  89 ----------
 .../ApplyInPandasWithStatePythonRunner.scala       |   6 +-
 .../execution/python/ArrowEvalPythonUDTFExec.scala |   2 +-
 .../sql/execution/python/ArrowPythonRunner.scala   |   6 +-
 .../execution/python/ArrowPythonUDTFRunner.scala   |   2 +-
 .../python/CoGroupedArrowPythonRunner.scala        |   6 +-
 .../python/FlatMapGroupsInPythonExec.scala         |   2 +-
 .../python/MapInBatchEvaluatorFactory.scala        |   2 +-
 .../sql/execution/python/MapInBatchExec.scala      |   2 +-
 .../sql/execution/python/PythonArrowInput.scala    |   4 +-
 .../sql/execution/python/PythonArrowOutput.scala   |   6 +-
 .../python/PythonDataSourcePartitionsExec.scala    |  80 ---------
 .../python/UserDefinedPythonDataSource.scala       | 195 ++++++++++++++++++---
 .../spark/sql/streaming/DataStreamReader.scala     |   5 +-
 .../spark/sql/streaming/DataStreamWriter.scala     |   2 +-
 .../execution/python/PythonDataSourceSuite.scala   |  66 ++++---
 26 files changed, 290 insertions(+), 359 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index fb8b06eb41bc..f5930c5272a2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, 
Expression, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.{BinaryType, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -103,42 +101,6 @@ case class PythonMapInArrow(
     copy(child = newChild)
 }
 
-/**
- * Represents a Python data source.
- */
-case class PythonDataSource(
-    dataSource: PythonFunction,
-    outputSchema: StructType,
-    override val output: Seq[Attribute]) extends LeafNode {
-  require(output.forall(_.resolved),
-    "Unresolved attributes found when constructing PythonDataSource.")
-  override protected def stringArgs: Iterator[Any] = {
-    Iterator(output)
-  }
-  final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_DATA_SOURCE)
-}
-
-/**
- * Represents a list of Python data source partitions.
- */
-case class PythonDataSourcePartitions(
-    output: Seq[Attribute],
-    partitions: Seq[Array[Byte]]) extends LeafNode {
-  override protected def stringArgs: Iterator[Any] = {
-    if (partitions.isEmpty) {
-      Iterator("<empty>", output)
-    } else {
-      Iterator(output)
-    }
-  }
-}
-
-object PythonDataSourcePartitions {
-  def schema: StructType = new StructType().add("partition", BinaryType)
-
-  def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
-}
-
 /**
  * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> 
pandas.Dataframe
  * This is used by DataFrame.groupby().cogroup().apply().
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index c29ffb329072..9992d8cbba07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,11 @@
 
 package org.apache.spark.sql
 
-import java.util.{Locale, Properties, ServiceConfigurationError}
+import java.util.{Locale, Properties}
 
 import scala.jdk.CollectionConverters._
-import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.{Partition, SparkClassNotFoundException, 
SparkThrowable}
+import org.apache.spark.Partition
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
       throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
     }
 
-    val isUserDefinedDataSource =
-      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
-
-    Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) 
match {
-      case Success(providerOpt) =>
-        // The source can be successfully loaded as either a V1 or a V2 data 
source.
-        // Check if it is also a user-defined data source.
-        if (isUserDefinedDataSource) {
-          throw QueryCompilationErrors.foundMultipleDataSources(source)
-        }
-        providerOpt.flatMap { provider =>
-          DataSourceV2Utils.loadV2Source(
-            sparkSession, provider, userSpecifiedSchema, extraOptions, source, 
paths: _*)
-        }.getOrElse(loadV1Source(paths: _*))
-      case Failure(exception) =>
-        // Exceptions are thrown while trying to load the data source as a V1 
or V2 data source.
-        // For the following not found exceptions, if the user-defined data 
source is defined,
-        // we can instead return the user-defined data source.
-        val isNotFoundError = exception match {
-          case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
-          case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
-          case e: ServiceConfigurationError => 
e.getCause.isInstanceOf[NoClassDefFoundError]
-          case _ => false
-        }
-        if (isNotFoundError && isUserDefinedDataSource) {
-          loadUserDefinedDataSource(paths)
-        } else {
-          // Throw the original exception.
-          throw exception
-        }
-    }
-  }
-
-  private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = 
sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
-    // Add `path` and `paths` options to the extra options if specified.
-    val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, 
paths: _*)
-    val plan = builder(sparkSession, source, userSpecifiedSchema, 
optionsWithPath)
-    Dataset.ofRows(sparkSession, plan)
+    DataSource.lookupDataSourceV2(source, 
sparkSession.sessionState.conf).flatMap { provider =>
+      DataSourceV2Utils.loadV2Source(sparkSession, provider, 
userSpecifiedSchema, extraOptions,
+        source, paths: _*)
+    }.getOrElse(loadV1Source(paths: _*))
   }
 
   private def loadV1Source(paths: String*) = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
index 15d26418984b..936286eb0da5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] 
(dataSourceManager: DataS
          | pythonExec: ${dataSource.dataSourceCls.pythonExec}
       """.stripMargin)
 
-    dataSourceManager.registerDataSource(name, dataSource.builder)
+    dataSourceManager.registerDataSource(name, dataSource)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 15eeca87dcf6..44a4d82c1dac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
     DataSource.lookupDataSource(runner, sessionState.conf) match {
       case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
         Dataset.ofRows(self, ExternalCommandExecutor(
-          source.getDeclaredConstructor().newInstance()
+          DataSource.newDataSourceInstance(runner, source)
             .asInstanceOf[ExternalCommandRunner], command, options))
 
       case _ =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 00328910f5b6..70a35ea91153 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.execution.datasources.{PlanPythonDataSourceScan, 
PruneFileSourcePartitions, SchemaPruning, V1Writes}
+import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, 
SchemaPruning, V1Writes}
 import 
org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning,
 OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, 
V2ScanRelationPushDown, V2Writes}
 import 
org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, 
PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
 import 
org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, 
ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -42,8 +42,7 @@ class SparkOptimizer(
       V2ScanRelationPushDown :+
       V2ScanPartitioningAndOrdering :+
       V2Writes :+
-      PruneFileSourcePartitions :+
-      PlanPythonDataSourceScan
+      PruneFileSourcePartitions
 
   override def preCBORules: Seq[Rule[LogicalPlan]] =
     OptimizeMetadataOnlyDeleteFromTable :: Nil
@@ -102,8 +101,7 @@ class SparkOptimizer(
     V2ScanRelationPushDown.ruleName :+
     V2ScanPartitioningAndOrdering.ruleName :+
     V2Writes.ruleName :+
-    ReplaceCTERefWithRepartition.ruleName :+
-    PlanPythonDataSourceScan.ruleName
+    ReplaceCTERefWithRepartition.ruleName
 
   /**
    * Optimization batches that are executed before the regular optimization 
batches (also before
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2d24f997d105..35070ac1d562 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -753,8 +753,6 @@ abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] {
       case ArrowEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child, 
evalType) =>
         ArrowEvalPythonUDTFExec(
           udtf, requiredChildOutput, resultAttrs, planLater(child), evalType) 
:: Nil
-      case PythonDataSourcePartitions(output, partitions) =>
-        PythonDataSourcePartitionsExec(output, partitions) :: Nil
       case _ =>
         Nil
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index dc1c5b3fd580..199c8728a5c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import 
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import 
org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources.{DataSource, 
DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
@@ -1025,7 +1025,9 @@ object DDLUtils extends Logging {
 
   def checkDataColNames(provider: String, schema: StructType): Unit = {
     val source = try {
-      DataSource.lookupDataSource(provider, 
SQLConf.get).getConstructor().newInstance()
+      DataSource.newDataSourceInstance(
+        provider,
+        DataSource.lookupDataSource(provider, SQLConf.get))
     } catch {
       case e: Throwable =>
         logError(s"Failed to find data source: $provider when check data 
column names.", e)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 2f8fca7cfd73..9771ee08b258 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import 
org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, 
quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, 
ResolveDefaultColumns}
+import org.apache.spark.sql.catalyst.util._
 import 
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
 import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
@@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand(
     }
 
     if (DDLUtils.isDatasourceTable(catalogTable)) {
-      DataSource.lookupDataSource(catalogTable.provider.get, conf).
-        getConstructor().newInstance() match {
+      DataSource.newDataSourceInstance(
+          catalogTable.provider.get,
+          DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
         // For datasource table, this command can only support the following 
File format.
         // TextFileFormat only default to one column "value"
         // Hive type is already considered as hive serde table, so the logic 
will not
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 71b6d4b886b4..9612d8ff24f5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -44,6 +44,7 @@ import 
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
 import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
+import org.apache.spark.sql.execution.python.PythonTableProvider
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, 
TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -105,13 +106,14 @@ case class DataSource(
     // [[FileDataSourceV2]] will still be used if we call the load()/save() 
method in
     // [[DataFrameReader]]/[[DataFrameWriter]], since they use method 
`lookupDataSource`
     // instead of `providingClass`.
-    cls.getDeclaredConstructor().newInstance() match {
+    DataSource.newDataSourceInstance(className, cls) match {
       case f: FileDataSourceV2 => f.fallbackFileFormat
       case _ => cls
     }
   }
 
-  private[sql] def providingInstance(): Any = 
providingClass.getConstructor().newInstance()
+  private[sql] def providingInstance(): Any =
+    DataSource.newDataSourceInstance(className, providingClass)
 
   private def newHadoopConfiguration(): Configuration =
     sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -622,6 +624,15 @@ object DataSource extends Logging {
     "org.apache.spark.sql.sources.HadoopFsRelationProvider",
     "org.apache.spark.Logging")
 
+  /** Create the instance of the datasource */
+  def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = 
{
+    providingClass match {
+      case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
+        cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
+      case cls => cls.getDeclaredConstructor().newInstance()
+    }
+  }
+
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
     val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) 
match {
@@ -649,6 +660,9 @@ object DataSource extends Logging {
                 // Found the data source using fully qualified path
                 dataSource
               case Failure(error) =>
+                // TODO(SPARK-45600): should be session-based.
+                val isUserDefinedDataSource = 
SparkSession.getActiveSession.exists(
+                  _.sessionState.dataSourceManager.dataSourceExists(provider))
                 if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
                   throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
                 } else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
@@ -657,6 +671,8 @@ object DataSource extends Logging {
                   throw 
QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
                 } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
                   throw 
QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
+                } else if (isUserDefinedDataSource) {
+                  classOf[PythonTableProvider]
                 } else {
                   throw 
QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
                 }
@@ -673,6 +689,14 @@ object DataSource extends Logging {
           }
         case head :: Nil =>
           // there is exactly one registered alias
+          // TODO(SPARK-45600): should be session-based.
+          val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
+            _.sessionState.dataSourceManager.dataSourceExists(provider))
+          // The source can be successfully loaded as either a V1 or a V2 data 
source.
+          // Check if it is also a user-defined data source.
+          if (isUserDefinedDataSource) {
+            throw QueryCompilationErrors.foundMultipleDataSources(provider)
+          }
           head.getClass
         case sources =>
           // There are multiple registered aliases for the input. If there is 
single datasource
@@ -708,9 +732,9 @@ object DataSource extends Logging {
   def lookupDataSourceV2(provider: String, conf: SQLConf): 
Option[TableProvider] = {
     val useV1Sources = 
conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
       .split(",").map(_.trim)
-    val cls = lookupDataSource(provider, conf)
+    val providingClass = lookupDataSource(provider, conf)
     val instance = try {
-      cls.getDeclaredConstructor().newInstance()
+      newDataSourceInstance(provider, providingClass)
     } catch {
       // Throw the original error from the data source implementation.
       case e: java.lang.reflect.InvocationTargetException => throw e.getCause
@@ -718,7 +742,8 @@ object DataSource extends Logging {
     instance match {
       case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => 
None
       case t: TableProvider
-          if 
!useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+          if !useV1Sources.contains(
+            providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
         Some(t)
       case _ => None
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 1cdc3d9cb69e..e6c4749df60a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -21,26 +21,18 @@ import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+
 
 /**
  * A manager for user-defined data sources. It is used to register and lookup 
data sources by
  * their short names or fully qualified names.
  */
 class DataSourceManager extends Logging {
-
-  private type DataSourceBuilder = (
-    SparkSession,  // Spark session
-    String,  // provider name
-    Option[StructType],  // user specified schema
-    CaseInsensitiveMap[String]  // options
-  ) => LogicalPlan
-
-  private val dataSourceBuilders = new ConcurrentHashMap[String, 
DataSourceBuilder]()
+  // TODO(SPARK-45917): Statically load Python Data Source so idempotently 
Python
+  //   Data Sources can be loaded even when the Driver is restarted.
+  private val dataSourceBuilders = new ConcurrentHashMap[String, 
UserDefinedPythonDataSource]()
 
   private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
 
@@ -48,9 +40,9 @@ class DataSourceManager extends Logging {
    * Register a data source builder for the given provider.
    * Note that the provider name is case-insensitive.
    */
-  def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
+  def registerDataSource(name: String, source: UserDefinedPythonDataSource): 
Unit = {
     val normalizedName = normalize(name)
-    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    val previousValue = dataSourceBuilders.put(normalizedName, source)
     if (previousValue != null) {
       logWarning(f"The data source $name replaced a previously registered data 
source.")
     }
@@ -60,7 +52,7 @@ class DataSourceManager extends Logging {
    * Returns a data source builder for the given provider and throw an 
exception if
    * it does not exist.
    */
-  def lookupDataSource(name: String): DataSourceBuilder = {
+  def lookupDataSource(name: String): UserDefinedPythonDataSource = {
     if (dataSourceExists(name)) {
       dataSourceBuilders.get(normalize(name))
     } else {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
deleted file mode 100644
index 7ffd61a4a266..000000000000
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources
-
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction, 
SimplePythonFunction}
-import org.apache.spark.sql.catalyst.expressions.PythonUDF
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, 
PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE
-import 
org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner
-import org.apache.spark.util.ArrayImplicits._
-
-/**
- * A logical rule to plan reads from a Python data source.
- *
- * This rule creates a Python process and invokes the `DataSource.reader` 
method to create an
- * instance of the user-defined data source reader, generates partitions if 
any, and returns
- * the information back to JVM (this rule) to construct the logical plan for 
Python data source.
- *
- * For example, prior to applying this rule, the plan might look like:
- *
- *   PythonDataSource(dataSource, schema, output)
- *
- * Here, `dataSource` is a serialized Python function that contains an 
instance of the DataSource
- * class. Post this rule, the plan is transformed into:
- *
- *  Project [output]
- *  +- PythonMapInArrow [read_from_data_source, ...]
- *     +- PythonDataSourcePartitions [partition_bytes]
- *
- * The PythonDataSourcePartitions contains a list of serialized partition 
values for the data
- * source. The `DataSourceReader.read` method will be planned as a MapInArrow 
operator that
- * accepts a partition value and yields the scanning output.
- */
-object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
-    _.containsPattern(PYTHON_DATA_SOURCE)) {
-    case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
-      val inputSchema = PythonDataSourcePartitions.schema
-
-      val info = new UserDefinedPythonDataSourceReadRunner(
-        dataSource, inputSchema, schema).runInPython()
-
-      val readerFunc = SimplePythonFunction(
-        command = info.func.toImmutableArraySeq,
-        envVars = dataSource.envVars,
-        pythonIncludes = dataSource.pythonIncludes,
-        pythonExec = dataSource.pythonExec,
-        pythonVer = dataSource.pythonVer,
-        broadcastVars = dataSource.broadcastVars,
-        accumulator = dataSource.accumulator)
-
-      val partitionPlan = PythonDataSourcePartitions(
-        PythonDataSourcePartitions.getOutputAttrs, info.partitions)
-
-      val pythonUDF = PythonUDF(
-        name = "read_from_data_source",
-        func = readerFunc,
-        dataType = schema,
-        children = partitionPlan.output,
-        evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
-        udfDeterministic = false)
-
-      // Construct the plan.
-      val plan = PythonMapInArrow(
-        pythonUDF,
-        ds.output,
-        partitionPlan,
-        isBarrier = false)
-
-      // Project out partition values.
-      Project(ds.output, plan)
-  }
-}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index cfe01f85cbe7..936ab866f5bf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -61,12 +61,14 @@ class ApplyInPandasWithStatePythonRunner(
     keySchema: StructType,
     outputSchema: StructType,
     stateValueSchema: StructType,
-    val pythonMetrics: Map[String, SQLMetric],
+    pyMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, 
jobArtifactUUID)
   with PythonArrowInput[InType]
   with PythonArrowOutput[OutType] {
 
+  override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
+
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head.funcs.head.pythonExec)
@@ -149,7 +151,7 @@ class ApplyInPandasWithStatePythonRunner(
 
       pandasWriter.finalizeGroup()
       val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
+      pythonMetrics.foreach(_("pythonDataSent") += deltaData)
       true
     } else {
       pandasWriter.finalizeData()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9e210bf5241b..2503deae7d5a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -70,7 +70,7 @@ case class ArrowEvalPythonUDTFExec(
       sessionLocalTimeZone,
       largeVarTypes,
       pythonRunnerConf,
-      pythonMetrics,
+      Some(pythonMetrics),
       jobArtifactUUID).compute(batchIter, context.partitionId(), context)
 
     columnarBatchIter.map { batch =>
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index a9eaf79c9db0..5dcb79cc2b91 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,7 @@ abstract class BaseArrowPythonRunner(
     _timeZoneId: String,
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
-    val pythonMetrics: Map[String, SQLMetric],
+    override val pythonMetrics: Option[Map[String, SQLMetric]],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
     funcs, evalType, argOffsets, jobArtifactUUID)
@@ -74,7 +74,7 @@ class ArrowPythonRunner(
     _timeZoneId: String,
     largeVarTypes: Boolean,
     workerConf: Map[String, String],
-    pythonMetrics: Map[String, SQLMetric],
+    pythonMetrics: Option[Map[String, SQLMetric]],
     jobArtifactUUID: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, 
workerConf,
@@ -100,7 +100,7 @@ class ArrowPythonWithNamedArgumentRunner(
     jobArtifactUUID: Option[String])
   extends BaseArrowPythonRunner(
     funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, 
largeVarTypes, workerConf,
-    pythonMetrics, jobArtifactUUID) {
+    Some(pythonMetrics), jobArtifactUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
     PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 87d1ccb25776..df2e89128124 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
     protected override val workerConf: Map[String, String],
-    val pythonMetrics: Map[String, SQLMetric],
+    override val pythonMetrics: Option[Map[String, SQLMetric]],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
       Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, 
Array(argMetas.map(_.offset)),
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index eb56298bfbee..70bd1ce82e2e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -46,13 +46,15 @@ class CoGroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     conf: Map[String, String],
-    val pythonMetrics: Map[String, SQLMetric],
+    pyMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String])
   extends BasePythonRunner[
     (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
     funcs, evalType, argOffsets, jobArtifactUUID)
   with BasicPythonArrowOutput {
 
+  override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
+
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head.funcs.head.pythonExec)
@@ -93,7 +95,7 @@ class CoGroupedArrowPythonRunner(
           writeGroup(nextRight, rightSchema, dataOut, "right")
 
           val deltaData = dataOut.size() - startData
-          pythonMetrics("pythonDataSent") += deltaData
+          pythonMetrics.foreach(_("pythonDataSent") += deltaData)
           true
         } else {
           dataOut.writeInt(0)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
index 0c18206a825a..e5a00e2cc8ea 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
@@ -88,7 +88,7 @@ trait FlatMapGroupsInPythonExec extends SparkPlan with 
UnaryExecNode with Python
         sessionLocalTimeZone,
         largeVarTypes,
         pythonRunnerConf,
-        pythonMetrics,
+        Some(pythonMetrics),
         jobArtifactUUID)
 
       executePython(data, output, runner)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 316c543ea807..00990ee46ea5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -36,7 +36,7 @@ class MapInBatchEvaluatorFactory(
     sessionLocalTimeZone: String,
     largeVarTypes: Boolean,
     pythonRunnerConf: Map[String, String],
-    pythonMetrics: Map[String, SQLMetric],
+    pythonMetrics: Option[Map[String, SQLMetric]],
     jobArtifactUUID: Option[String])
     extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 8db389f02667..6db6c96b426a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -57,7 +57,7 @@ trait MapInBatchExec extends UnaryExecNode with 
PythonSQLMetrics {
       conf.sessionLocalTimeZone,
       conf.arrowUseLargeVarTypes,
       pythonRunnerConf,
-      pythonMetrics,
+      Some(pythonMetrics),
       jobArtifactUUID)
 
     if (isBarrier) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 1e075cab9224..6d0f31f35ff7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -46,7 +46,7 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
 
   protected val largeVarTypes: Boolean
 
-  protected def pythonMetrics: Map[String, SQLMetric]
+  protected def pythonMetrics: Option[Map[String, SQLMetric]]
 
   protected def writeNextInputToArrowStream(
       root: VectorSchemaRoot,
@@ -132,7 +132,7 @@ private[python] trait BasicPythonArrowInput extends 
PythonArrowInput[Iterator[In
       writer.writeBatch()
       arrowWriter.reset()
       val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
+      pythonMetrics.foreach(_("pythonDataSent") += deltaData)
       true
     } else {
       super[PythonArrowInput].close()
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 90922d89ad10..82e8e7aa4f64 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, 
ColumnarBatch, Column
  */
 private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: 
BasePythonRunner[_, OUT] =>
 
-  protected def pythonMetrics: Map[String, SQLMetric]
+  protected def pythonMetrics: Option[Map[String, SQLMetric]]
 
   protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
 
@@ -91,8 +91,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { 
self: BasePythonRunner[
               val rowCount = root.getRowCount
               batch.setNumRows(root.getRowCount)
               val bytesReadEnd = reader.bytesRead()
-              pythonMetrics("pythonNumRowsReceived") += rowCount
-              pythonMetrics("pythonDataReceived") += bytesReadEnd - 
bytesReadStart
+              pythonMetrics.foreach(_("pythonNumRowsReceived") += rowCount)
+              pythonMetrics.foreach(_("pythonDataReceived") += bytesReadEnd - 
bytesReadStart)
               deserializeColumnarBatch(batch, schema)
             } else {
               reader.close(false)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
deleted file mode 100644
index 8f1595cfdd71..000000000000
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.python
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{InputRDDCodegen, LeafExecNode, 
SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ArrayImplicits._
-
-/**
- * A physical plan node for scanning data from a list of data source partition 
values.
- *
- * It creates a RDD with number of partitions equal to size of the partition 
value list and
- * each partition contains a single row with a serialized partition value.
- */
-case class PythonDataSourcePartitionsExec(
-    output: Seq[Attribute],
-    partitions: Seq[Array[Byte]]) extends LeafExecNode with InputRDDCodegen {
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output 
rows"))
-
-  @transient private lazy val unsafeRows: Array[InternalRow] = {
-    if (partitions.isEmpty) {
-      Array.empty
-    } else {
-      val proj = UnsafeProjection.create(output, output)
-      partitions.map(p => proj(InternalRow(p)).copy()).toArray
-    }
-  }
-
-  @transient private lazy val rdd: RDD[InternalRow] = {
-    val numPartitions = partitions.size
-    if (numPartitions == 0) {
-      sparkContext.emptyRDD
-    } else {
-      sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numPartitions)
-    }
-  }
-
-  override def inputRDD: RDD[InternalRow] = rdd
-
-  override protected val createUnsafeProjection: Boolean = false
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    longMetric("numOutputRows").add(partitions.size)
-    sendDriverMetrics()
-    rdd
-  }
-
-  override protected def stringArgs: Iterator[Any] = {
-    if (partitions.isEmpty) {
-      Iterator("<empty>", output)
-    } else {
-      Iterator(output)
-    }
-  }
-
-  private def sendDriverMetrics(): Unit = {
-    val executionId = 
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, 
metrics.values.toSeq)
-  }
-}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 2c8e1b942727..7c850d1e2890 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,58 +20,199 @@ package org.apache.spark.sql.execution.python
 import java.io.{DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, 
SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, 
PythonDataSource}
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, 
PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, 
TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{Batch, InputPartition, 
PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
 
+/**
+ * Data Source V2 wrapper for Python Data Source.
+ */
+class PythonTableProvider(shortName: String) extends TableProvider {
+  private var dataSourceInPython: PythonDataSourceCreationResult = _
+  private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private lazy val source: UserDefinedPythonDataSource =
+    
SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    if (dataSourceInPython == null) {
+      dataSourceInPython = source.createDataSourceInPython(shortName, options, 
None)
+    }
+    dataSourceInPython.schema
+  }
+
+  override def getTable(
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: java.util.Map[String, String]): Table = {
+    val outputSchema = schema
+    new Table with SupportsRead {
+      override def name(): String = shortName
+
+      override def capabilities(): java.util.Set[TableCapability] = 
java.util.EnumSet.of(
+        BATCH_READ)
+
+      override def newScanBuilder(options: CaseInsensitiveStringMap): 
ScanBuilder = {
+        new ScanBuilder with Batch with Scan {
+
+          private lazy val infoInPython: PythonDataSourceReadInfo = {
+            if (dataSourceInPython == null) {
+              dataSourceInPython = source
+                .createDataSourceInPython(shortName, options, 
Some(outputSchema))
+            }
+            source.createReadInfoInPython(dataSourceInPython, outputSchema)
+          }
+
+          override def build(): Scan = this
+
+          override def toBatch: Batch = this
+
+          override def readSchema(): StructType = outputSchema
+
+          override def planInputPartitions(): Array[InputPartition] =
+            infoInPython.partitions.zipWithIndex.map(p => 
PythonInputPartition(p._2, p._1)).toArray
+
+          override def createReaderFactory(): PartitionReaderFactory = {
+            val readerFunc = infoInPython.func
+            new PythonPartitionReaderFactory(
+              source, readerFunc, outputSchema, jobArtifactUUID)
+          }
+        }
+      }
+
+      override def schema(): StructType = outputSchema
+    }
+  }
+
+  override def supportsExternalMetadata(): Boolean = true
+}
+
+case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) 
extends InputPartition
+
+class PythonPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory {
+
+  override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
+    new PartitionReader[InternalRow] {
+      private val outputIter = source.createPartitionReadIteratorInPython(
+        partition.asInstanceOf[PythonInputPartition],
+        pickledReadFunc,
+        outputSchema,
+        jobArtifactUUID)
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+    }
+  }
+}
+
 /**
  * A user-defined Python data source. This is used by the Python API.
+ * Defines the interaction between Python and the JVM.
  *
  * @param dataSourceCls The Python data source class.
  */
 case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
 
-  def builder(
-      sparkSession: SparkSession,
-      provider: String,
-      userSpecifiedSchema: Option[StructType],
-      options: CaseInsensitiveMap[String]): LogicalPlan = {
+  private val inputSchema: StructType = new StructType().add("partition", 
BinaryType)
+
+  /**
+   * (Driver-side) Run Python process, and get the pickled Python Data Source
+   * instance and its schema.
+   */
+  def createDataSourceInPython(
+      shortName: String,
+      options: CaseInsensitiveStringMap,
+      userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult 
= {
+    new UserDefinedPythonDataSourceRunner(
+      dataSourceCls,
+      shortName,
+      userSpecifiedSchema,
+      
CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython()
+  }
 
-    val runner = new UserDefinedPythonDataSourceRunner(
-      dataSourceCls, provider, userSpecifiedSchema, options)
+  /**
+   * (Driver-side) Run Python process, and get the partition read functions, 
and
+   * partition information.
+   */
+  def createReadInfoInPython(
+      pythonResult: PythonDataSourceCreationResult,
+      outputSchema: StructType): PythonDataSourceReadInfo = {
+    new UserDefinedPythonDataSourceReadRunner(
+      createPythonFunction(
+        pythonResult.dataSource), inputSchema, outputSchema).runInPython()
+  }
 
-    val result = runner.runInPython()
-    val pickledDataSourceInstance = result.dataSource
+  /**
+   * (Executor-side) Create an iterator that reads the input partitions.
+   */
+  def createPartitionReadIteratorInPython(
+      partition: PythonInputPartition,
+      pickledReadFunc: Array[Byte],
+      outputSchema: StructType,
+      jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
+    val readerFunc = createPythonFunction(pickledReadFunc)
+
+    val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+    val pythonUDF = PythonUDF(
+      name = "read_from_data_source",
+      func = readerFunc,
+      dataType = outputSchema,
+      children = toAttributes(inputSchema),
+      evalType = pythonEvalType,
+      udfDeterministic = false)
+
+    val conf = SQLConf.get
+
+    val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+    val evaluatorFactory = new MapInBatchEvaluatorFactory(
+      toAttributes(outputSchema),
+      Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
+      inputSchema,
+      conf.arrowMaxRecordsPerBatch,
+      pythonEvalType,
+      conf.sessionLocalTimeZone,
+      conf.arrowUseLargeVarTypes,
+      pythonRunnerConf,
+      None,
+      jobArtifactUUID)
+
+    evaluatorFactory.createEvaluator().eval(
+      partition.index, Iterator.single(InternalRow(partition.pickedPartition)))
+  }
 
-    val dataSource = SimplePythonFunction(
-      command = pickledDataSourceInstance.toImmutableArraySeq,
+  private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = 
{
+    SimplePythonFunction(
+      command = pickledFunc.toImmutableArraySeq,
       envVars = dataSourceCls.envVars,
       pythonIncludes = dataSourceCls.pythonIncludes,
       pythonExec = dataSourceCls.pythonExec,
       pythonVer = dataSourceCls.pythonVer,
       broadcastVars = dataSourceCls.broadcastVars,
       accumulator = dataSourceCls.accumulator)
-    val schema = result.schema
-
-    PythonDataSource(dataSource, schema, output = toAttributes(schema))
-  }
-
-  def apply(
-      sparkSession: SparkSession,
-      provider: String,
-      userSpecifiedSchema: Option[StructType] = None,
-      options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): 
DataFrame = {
-    val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
-    Dataset.ofRows(sparkSession, plan)
   }
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 1a69678c2f54..c93ca632d3c7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,8 +156,9 @@ final class DataStreamReader private[sql](sparkSession: 
SparkSession) extends Lo
       extraOptions + ("path" -> path.get)
     }
 
-    val ds = DataSource.lookupDataSource(source, 
sparkSession.sessionState.conf).
-      getConstructor().newInstance()
+    val ds = DataSource.newDataSourceInstance(
+      source,
+      DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
     // We need to generate the V1 data source so we can pass it to the V2 
relation as a shim.
     // We can't be sure at this point whether we'll actually want to use V2, 
since we don't know the
     // writer or whether the query is continuous.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 95aa2f8c7a4e..7202f69ab1bf 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: 
Dataset[T]) {
       }
 
       val sink = if (classOf[TableProvider].isAssignableFrom(cls) && 
!useV1Source) {
-        val provider = 
cls.getConstructor().newInstance().asInstanceOf[TableProvider]
+        val provider = DataSource.newDataSourceInstance(source, 
cls).asInstanceOf[TableProvider]
         val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
           source = provider, conf = df.sparkSession.sessionState.conf)
         val finalOptions = sessionOptions.filter { case (k, _) => 
!optionsWithPath.contains(k) } ++
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 6bc9166117f2..53a54abf8392 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, 
QueryTest, Row}
-import 
org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions, 
PythonMapInArrow}
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -53,12 +52,13 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
       name = dataSourceName, pythonScript = dataSourceScript)
-    val df = dataSource.apply(
-      spark, provider = dataSourceName, userSpecifiedSchema = Some(schema))
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).schema(schema).load()
     assert(df.rdd.getNumPartitions == 2)
     val plan = df.queryExecution.optimizedPlan
     plan match {
-      case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) =>
+      case s: DataSourceV2ScanRelation
+        if s.relation.table.getClass.toString.contains("PythonTable") =>
       case _ => fail(s"Plan did not match the expected pattern. Actual plan:\n$plan")
     }
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
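
The dataSourceScript registered by this test is defined above this hunk and is not shown in the excerpt. A minimal stand-in that would produce the same two partitions and six rows looks roughly like the following, written the way the suite embeds such scripts (a Python string inside Scala); the class and value names here are illustrative, not the suite's actual script:

    // Hypothetical stand-in for the suite's dataSourceScript: two partitions,
    // three (id, partition) rows per partition, matching the checkAnswer above.
    val exampleScript =
      """
        |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
        |
        |class SimpleDataSourceReader(DataSourceReader):
        |    def partitions(self):
        |        return [InputPartition(0), InputPartition(1)]
        |
        |    def read(self, partition):
        |        for i in range(3):
        |            yield (i, partition.value)
        |
        |class SimpleDataSource(DataSource):
        |    def reader(self, schema):
        |        return SimpleDataSourceReader()
        |""".stripMargin
    // Registered and read back as in the hunk above:
    //   spark.dataSource.registerPython("SimpleDataSource", dataSource)
    //   spark.read.format("SimpleDataSource").schema("id INT, partition INT").load()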
@@ -79,7 +79,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return SimpleDataSourceReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
-    val df = dataSource(spark, provider = dataSourceName)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
   }
 
@@ -102,7 +103,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return SimpleDataSourceReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
-    val df = dataSource(spark, provider = dataSourceName)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
   }
 
@@ -121,8 +123,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return SimpleDataSourceReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
     checkError(
-      exception = intercept[AnalysisException](dataSource(spark, provider = dataSourceName)),
+      exception = intercept[AnalysisException](spark.read.format(dataSourceName).load()),
       errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
       parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
   }
@@ -145,9 +148,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
     assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
-    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
     checkAnswer(
-      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      spark.read.format(dataSourceName).load(),
       Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
     // Should be able to override an already registered data source.
@@ -168,10 +170,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
     spark.dataSource.registerPython(dataSourceName, newDataSource)
     assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
-
-    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
     checkAnswer(
-      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      spark.read.format(dataSourceName).load(),
       Seq(Row(0)))
   }
 
@@ -195,12 +195,12 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |            paths = []
          |        return [InputPartition(p) for p in paths]
          |
-         |    def read(self, path):
-         |        if path is not None:
-         |            assert isinstance(path, InputPartition)
-         |            yield (path.value, 1)
+         |    def read(self, part):
+         |        if part is not None:
+         |            assert isinstance(part, InputPartition)
+         |            yield (part.value, 1)
          |        else:
-         |            yield (path, 1)
+         |            yield (part, 1)
          |
          |class $dataSourceName(DataSource):
          |    @classmethod
@@ -218,6 +218,12 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
     checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
     checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
+
+    withTable("tblA") {
+      sql("CREATE TABLE tblA USING test")
+      // The path will be the actual temp path.
+      checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1)))
+    }
   }
 
   test("reader not implemented") {
@@ -231,8 +237,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
       name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
     val err = intercept[AnalysisException] {
-      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
+      spark.read.format(dataSourceName).schema(schema).load().collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED"))
@@ -250,8 +257,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
       name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
     val err = intercept[AnalysisException] {
-      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
+      spark.read.format(dataSourceName).schema(schema).load().collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
@@ -269,8 +277,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val schema = StructType.fromDDL("id INT, partition INT")
     val dataSource = createUserDefinedPythonDataSource(
       name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
     val err = intercept[AnalysisException] {
-      dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect()
+      spark.read.format(dataSourceName).schema(schema).load().collect()
     }
     assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
     assert(err.getMessage.contains("PYTHON_DATA_SOURCE_TYPE_MISMATCH"))
@@ -278,7 +287,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source read with custom partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
@@ -304,12 +313,13 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return SimpleDataSourceReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
-    val df = dataSource(spark, provider = dataSourceName)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
     checkAnswer(df, Seq(Row(1), Row(3)))
   }
 
   test("data source read with empty partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -331,12 +341,13 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
          |        return SimpleDataSourceReader()
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
-    val df = dataSource(spark, provider = dataSourceName)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df = spark.read.format(dataSourceName).load()
     checkAnswer(df, Row("success"))
   }
 
   test("data source read with invalid partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val reader1 =
       s"""
          |class SimpleDataSourceReader(DataSourceReader):
@@ -378,8 +389,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
            |        return SimpleDataSourceReader()
            |""".stripMargin
       val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+      spark.dataSource.registerPython(dataSourceName, dataSource)
       val err = intercept[AnalysisException](
-        dataSource(spark, provider = dataSourceName).collect())
+        spark.read.format(dataSourceName).load().collect())
       assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
       assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
     }
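
Taken together, the suite changes above exercise the new end-to-end flow: register a Python data source once, then reach it from the DataFrame API or from SQL. A condensed sketch of that flow (the "myPythonSource" name and the `dataSource` value are illustrative; `dataSource` stands for a UserDefinedPythonDataSource built from a Python script, as the suite does with createUserDefinedPythonDataSource):

    // Register the Python-backed source under a short name.
    spark.dataSource.registerPython("myPythonSource", dataSource)

    // DataFrame API: the read is planned as a DSv2 scan over the Python table,
    // as the optimized-plan check in the first test above expects.
    val df = spark.read.format("myPythonSource").load()

    // SQL: create a table backed by the registered source and query it,
    // mirroring the withTable("tblA") block added to the test above.
    spark.sql("CREATE TABLE tblA USING myPythonSource")
    spark.table("tblA").show()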

