This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f3ffd3ca68d2 [SPARK-46720][SQL][PYTHON] Refactor Python Data Source to align with other DSv2 built-in Data Sources
f3ffd3ca68d2 is described below

commit f3ffd3ca68d27407a81e0406dbffe92b03a2d098
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Jan 15 22:24:03 2024 +0900

    [SPARK-46720][SQL][PYTHON] Refactor Python Data Source to align with other DSv2 built-in Data Sources
    
    ### What changes were proposed in this pull request?
    
    This PR refactors Python Data Source to align with other DSv2 built-in Data Sources such as CSV, Parquet, ORC, JDBC, etc.
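    
    As context for reviewers, a minimal usage sketch of the user-facing path that the refactored DSv2 classes (e.g. `PythonDataSourceV2` in this diff) back. It assumes the Python Data Source API; the `fake_range` source, its `n` option, and the class names below are illustrative only and not part of this change:
    
        from pyspark.sql.datasource import DataSource, DataSourceReader
    
        class FakeRangeReader(DataSourceReader):
            def __init__(self, schema, options):
                # "n" is a hypothetical option used only for this sketch.
                self.n = int(options.get("n", 3))
    
            def read(self, partition):
                # Yield plain tuples matching the schema declared below.
                for i in range(self.n):
                    yield (i,)
    
        class FakeRangeDataSource(DataSource):
            @classmethod
            def name(cls):
                return "fake_range"
    
            def schema(self):
                return "id INT"
    
            def reader(self, schema):
                return FakeRangeReader(schema, self.options)
    
        # `spark` is an active SparkSession. Registration goes through
        # DataSourceRegistration (see the datasource.py hunk below); the read
        # then resolves to PythonDataSourceV2 -> PythonTable -> PythonScan.
        spark.dataSource.register(FakeRangeDataSource)
        spark.read.format("fake_range").option("n", "5").load().show()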
    
    ### Why are the changes needed?
    
    For better readability, maintenance, and consistency.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing test cases should cover them.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44734 from HyukjinKwon/SPARK-46720.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/datasource.py                   |   5 +-
 .../apache/spark/sql/DataSourceRegistration.scala  |   2 +-
 .../sql/execution/datasources/DataSource.scala     |   6 +-
 .../execution/datasources/DataSourceManager.scala  |   2 +-
 .../v2/python/PythonBatchWriterFactory.scala       |  80 ++++++
 .../datasources/v2/python/PythonCustomMetric.scala |  51 ++++
 .../datasources/v2/python/PythonDataSourceV2.scala |  67 +++++
 .../v2/python/PythonPartitionReaderFactory.scala   |  66 +++++
 .../datasources/v2/python/PythonScan.scala         |  57 +++++
 .../datasources/v2/python/PythonScanBuilder.scala  |  30 +++
 .../datasources/v2/python/PythonTable.scala        |  46 ++++
 .../datasources/v2/python/PythonWrite.scala        |  65 +++++
 .../datasources/v2/python/PythonWriteBuilder.scala |  36 +++
 .../v2}/python/UserDefinedPythonDataSource.scala   | 275 +--------------------
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |   3 +-
 15 files changed, 520 insertions(+), 271 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index bdedbac3544e..1e50f8270243 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -407,7 +407,10 @@ class DataSourceRegistration:
         # Serialize the data source class.
         wrapped = _wrap_function(sc, dataSource)
         assert sc._jvm is not None
-        ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
+        jvm = sc._jvm
+        ds = jvm.org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource(
+            wrapped
+        )
         self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
index 2434103f4b80..63cee8861c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Evolving
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager}
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
 import org.apache.spark.sql.internal.SQLConf
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5190075f652b..3e03fd652f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -43,8 +43,8 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
 import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
-import org.apache.spark.sql.execution.python.PythonTableProvider
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -661,7 +661,7 @@ object DataSource extends Logging {
                 } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
                   throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
                 } else if (isUserDefinedDataSource) {
-                  classOf[PythonTableProvider]
+                  classOf[PythonDataSourceV2]
                 } else {
                   throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
                 }
@@ -734,7 +734,7 @@ object DataSource extends Logging {
       case t: TableProvider
           if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
         t match {
-          case p: PythonTableProvider => p.setShortName(provider)
+          case p: PythonDataSourceV2 => p.setShortName(provider)
           case _ =>
         }
         Some(t)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 53003989a338..28c93357d8b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
 import org.apache.spark.util.Utils
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
new file mode 100644
index 000000000000..d5412f1bdd38
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import scala.jdk.CollectionConverters.IteratorHasAsScala
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+
+case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage
+
+case class PythonBatchWriterFactory(
+    source: UserDefinedPythonDataSource,
+    pickledWriteFunc: Array[Byte],
+    inputSchema: StructType,
+    jobArtifactUUID: Option[String]) extends DataWriterFactory {
+  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+    new DataWriter[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private var commitMessage: PythonWriterCommitMessage = _
+
+      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledWriteFunc,
+          "write_to_data_source",
+          inputSchema,
+          UserDefinedPythonDataSource.writeOutputSchema,
+          metrics,
+          jobArtifactUUID)
+        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
+        outputIter.foreach { row =>
+          if (commitMessage == null) {
+            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
+          } else {
+            throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
+          }
+        }
+        if (commitMessage == null) {
+          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
+        }
+      }
+
+      override def write(record: InternalRow): Unit =
+        SparkException.internalError("write method for Python data source should not be called.")
+
+      override def commit(): WriterCommitMessage = {
+        commitMessage.asInstanceOf[WriterCommitMessage]
+      }
+
+      override def abort(): Unit = {}
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
new file mode 100644
index 000000000000..bca1cbed7e70
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.python.PythonSQLMetrics
+
+
+class PythonCustomMetric(
+    override val name: String,
+    override val description: String) extends CustomMetric {
+  // To allow the aggregation can be called. See `SQLAppStatusListener.aggregateMetrics`
+  def this() = this(null, null)
+
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+  }
+}
+
+class PythonCustomTaskMetric(
+    override val name: String,
+    override val value: Long) extends CustomTaskMetric
+
+
+object PythonCustomMetric {
+  val pythonMetrics: Map[String, SQLMetric] = {
+    // Dummy SQLMetrics. The result is manually reported via DSv2 interface
+    // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
+    // is not used when it is reported. It is to reuse existing Python runner.
+    // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+    PythonSQLMetrics.pythonSizeMetricsDesc.keys
+      .map(_ -> new SQLMetric("size", -1)).toMap ++
+      PythonSQLMetrics.pythonOtherMetricsDesc.keys
+        .map(_ -> new SQLMetric("sum", -1)).toMap
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
new file mode 100644
index 000000000000..edea70258779
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Data Source V2 wrapper for Python Data Source.
+ */
+class PythonDataSourceV2 extends TableProvider {
+  private var name: String = _
+
+  def setShortName(str: String): Unit = {
+    assert(name == null)
+    name = str
+  }
+
+  private def shortName: String = {
+    assert(name != null)
+    name
+  }
+
+  private var dataSourceInPython: PythonDataSourceCreationResult = _
+  private[python] lazy val source: UserDefinedPythonDataSource =
+    SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
+
+  def getOrCreateDataSourceInPython(
+      shortName: String,
+      options: CaseInsensitiveStringMap,
+      userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
+    if (dataSourceInPython == null) {
+      dataSourceInPython = source.createDataSourceInPython(shortName, options, userSpecifiedSchema)
+    }
+    dataSourceInPython
+  }
+
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    getOrCreateDataSourceInPython(shortName, options, None).schema
+  }
+
+  override def getTable(
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: java.util.Map[String, String]): Table = {
+    new PythonTable(this, shortName, schema)
+  }
+
+  override def supportsExternalMetadata(): Boolean = true
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
new file mode 100644
index 000000000000..44933779c26a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+
+
+case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
+
+class PythonPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory {
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    new PartitionReader[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
new file mode 100644
index 000000000000..75cbe38b1397
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonScan(
+     ds: PythonDataSourceV2,
+     shortName: String,
+     outputSchema: StructType,
+     options: CaseInsensitiveStringMap) extends Batch with Scan {
+
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  private lazy val infoInPython: PythonDataSourceReadInfo = {
+    ds.source.createReadInfoInPython(
+      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
+      outputSchema)
+  }
+
+  override def planInputPartitions(): Array[InputPartition] =
+    infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    val readerFunc = infoInPython.func
+    new PythonPartitionReaderFactory(
+      ds.source, readerFunc, outputSchema, jobArtifactUUID)
+  }
+
+  override def toBatch: Batch = this
+
+  override def description: String = "(Python)"
+
+  override def readSchema(): StructType = outputSchema
+
+  override def supportedCustomMetrics(): Array[CustomMetric] =
+    ds.source.createPythonMetrics()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
new file mode 100644
index 000000000000..e30fc9f7978c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonScanBuilder(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap) extends ScanBuilder {
+  override def build(): Scan = new PythonScan(ds, shortName, outputSchema, options)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
new file mode 100644
index 000000000000..6bea97795a35
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonTable(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType
+  ) extends Table with SupportsRead with SupportsWrite {
+  override def name(): String = shortName
+
+  override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
+    BATCH_READ, BATCH_WRITE, TRUNCATE)
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+    new PythonScanBuilder(ds, shortName, outputSchema, options)
+  }
+
+  override def schema(): StructType = outputSchema
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    new PythonWriteBuilder(ds, shortName, info)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
new file mode 100644
index 000000000000..d216dfde9974
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.write._
+
+
+class PythonWrite(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    info: LogicalWriteInfo,
+    isTruncate: Boolean
+  ) extends Write with BatchWrite {
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  // Store the pickled data source writer instance.
+  private var pythonDataSourceWriter: Array[Byte] = _
+
+  override def createBatchWriterFactory(
+      physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
+
+    val writeInfo = ds.source.createWriteInfoInPython(
+      shortName,
+      info.schema(),
+      info.options(),
+      isTruncate)
+
+    pythonDataSourceWriter = writeInfo.writer
+
+    PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID)
+  }
+
+  override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    ds.source.commitWriteInPython(pythonDataSourceWriter, messages)
+  }
+
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {
+    ds.source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true)
+  }
+
+  override def toString: String = shortName
+
+  override def toBatch: BatchWrite = this
+
+  override def description: String = "(Python)"
+
+  override def supportedCustomMetrics(): Array[CustomMetric] =
+    ds.source.createPythonMetrics()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala
new file mode 100644
index 000000000000..dffc184ffb87
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.write._
+
+
+class PythonWriteBuilder(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    info: LogicalWriteInfo
+  ) extends WriteBuilder with SupportsTruncate {
+
+  private var isTruncate = false
+
+  override def truncate(): WriteBuilder = {
+    isTruncate = true
+    this
+  }
+
+  override def build(): Write = new PythonWrite(ds, shortName, info, isTruncate)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
similarity index 61%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 4b567c591672..f11a63429d78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.python
+package org.apache.spark.sql.execution.datasources.v2.python
 
 import java.io.{DataInputStream, DataOutputStream}
 
@@ -24,273 +24,20 @@ import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.{JobArtifactSet, SparkException}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonUtils, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
-import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, Write, WriteBuilder, WriterCommitMessage}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
 
-/**
- * Data Source V2 wrapper for Python Data Source.
- */
-class PythonTableProvider extends TableProvider {
-  private var name: String = _
-  def setShortName(str: String): Unit = {
-    assert(name == null)
-    name = str
-  }
-  private def shortName: String = {
-    assert(name != null)
-    name
-  }
-  private var dataSourceInPython: PythonDataSourceCreationResult = _
-  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-  private lazy val source: UserDefinedPythonDataSource =
-    SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
-  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
-    if (dataSourceInPython == null) {
-      dataSourceInPython = source.createDataSourceInPython(shortName, options, None)
-    }
-    dataSourceInPython.schema
-  }
-
-  override def getTable(
-      schema: StructType,
-      partitioning: Array[Transform],
-      properties: java.util.Map[String, String]): Table = {
-    val outputSchema = schema
-    new Table with SupportsRead with SupportsWrite {
-      override def name(): String = shortName
-
-      override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-        BATCH_READ, BATCH_WRITE, TRUNCATE)
-
-      override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-        new ScanBuilder with Batch with Scan {
-
-          private lazy val infoInPython: PythonDataSourceReadInfo = {
-            if (dataSourceInPython == null) {
-              dataSourceInPython = source
-                .createDataSourceInPython(shortName, options, Some(outputSchema))
-            }
-            source.createReadInfoInPython(dataSourceInPython, outputSchema)
-          }
-
-          override def build(): Scan = this
-
-          override def toBatch: Batch = this
-
-          override def readSchema(): StructType = outputSchema
-
-          override def planInputPartitions(): Array[InputPartition] =
-            infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray
-
-          override def createReaderFactory(): PartitionReaderFactory = {
-            val readerFunc = infoInPython.func
-            new PythonPartitionReaderFactory(
-              source, readerFunc, outputSchema, jobArtifactUUID)
-          }
-
-          override def description: String = "(Python)"
-
-          override def supportedCustomMetrics(): Array[CustomMetric] =
-            source.createPythonMetrics()
-        }
-      }
-
-      override def schema(): StructType = outputSchema
-
-      override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
-        new WriteBuilder with SupportsTruncate {
-
-          private var isTruncate = false
-
-          override def truncate(): WriteBuilder = {
-            isTruncate = true
-            this
-          }
-
-          override def build(): Write = new Write {
-
-            override def toBatch: BatchWrite = new BatchWrite {
-
-              // Store the pickled data source writer instance.
-              private var pythonDataSourceWriter: Array[Byte] = _
-
-              override def createBatchWriterFactory(
-                physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
-
-                val writeInfo = source.createWriteInfoInPython(
-                  shortName,
-                  info.schema(),
-                  info.options(),
-                  isTruncate)
-
-                pythonDataSourceWriter = writeInfo.writer
-
-                PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
-              }
-
-              override def commit(messages: Array[WriterCommitMessage]): Unit = {
-                source.commitWriteInPython(pythonDataSourceWriter, messages)
-              }
-
-              override def abort(messages: Array[WriterCommitMessage]): Unit = {
-                source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true)
-              }
-
-              override def toString: String = shortName
-            }
-
-            override def description: String = "(Python)"
-
-            override def supportedCustomMetrics(): Array[CustomMetric] =
-              source.createPythonMetrics()
-          }
-        }
-      }
-    }
-  }
-
-  override def supportsExternalMetadata(): Boolean = true
-}
-
-case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
-
-class PythonPartitionReaderFactory(
-    source: UserDefinedPythonDataSource,
-    pickledReadFunc: Array[Byte],
-    outputSchema: StructType,
-    jobArtifactUUID: Option[String])
-  extends PartitionReaderFactory with PythonDataSourceSQLMetrics {
-
-  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    new PartitionReader[InternalRow] {
-
-      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
-
-      private val outputIter = {
-        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-          pickledReadFunc,
-          "read_from_data_source",
-          UserDefinedPythonDataSource.readInputSchema,
-          outputSchema,
-          metrics,
-          jobArtifactUUID)
-
-        val part = partition.asInstanceOf[PythonInputPartition]
-        evaluatorFactory.createEvaluator().eval(
-          part.index, Iterator.single(InternalRow(part.pickedPartition)))
-      }
-
-      override def next(): Boolean = outputIter.hasNext
-
-      override def get(): InternalRow = outputIter.next()
-
-      override def close(): Unit = {}
-
-      override def currentMetricsValues(): Array[CustomTaskMetric] = {
-        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
-      }
-    }
-  }
-}
-
-case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage
-
-private case class PythonBatchWriterFactory(
-    source: UserDefinedPythonDataSource,
-    pickledWriteFunc: Array[Byte],
-    inputSchema: StructType,
-    jobArtifactUUID: Option[String]) extends DataWriterFactory with PythonDataSourceSQLMetrics {
-  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
-    new DataWriter[InternalRow] {
-
-      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
-
-      private var commitMessage: PythonWriterCommitMessage = _
-
-      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
-        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-          pickledWriteFunc,
-          "write_to_data_source",
-          inputSchema,
-          UserDefinedPythonDataSource.writeOutputSchema,
-          metrics,
-          jobArtifactUUID)
-        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
-        outputIter.foreach { row =>
-          if (commitMessage == null) {
-            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
-          } else {
-            throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
-          }
-        }
-        if (commitMessage == null) {
-          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
-        }
-      }
-
-      override def write(record: InternalRow): Unit =
-        SparkException.internalError("write method for Python data source should not be called.")
-
-      override def commit(): WriterCommitMessage = {
-        commitMessage.asInstanceOf[WriterCommitMessage]
-      }
-
-      override def abort(): Unit = {}
-
-      override def close(): Unit = {}
-
-      override def currentMetricsValues(): Array[CustomTaskMetric] = {
-        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
-      }
-    }
-  }
-}
-
-trait PythonDataSourceSQLMetrics {
-  // Dummy SQLMetrics. The result is manually reported via DSv2 interface
-  // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
-  // is not used when it is reported. It is to reuse existing Python runner.
-  // See also `UserDefinedPythonDataSource.createPythonMetrics`.
-  protected lazy val pythonMetrics: Map[String, SQLMetric] = {
-    PythonSQLMetrics.pythonSizeMetricsDesc.keys
-      .map(_ -> new SQLMetric("size", -1)).toMap ++
-      PythonSQLMetrics.pythonOtherMetricsDesc.keys
-        .map(_ -> new SQLMetric("sum", -1)).toMap
-  }
-}
-
-class PythonCustomMetric(
-  override val name: String,
-  override val description: String) extends CustomMetric {
-  // To allow the aggregation can be called. See `SQLAppStatusListener.aggregateMetrics`
-  def this() = this(null, null)
-
-  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
-    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
-  }
-}
-
-class PythonCustomTaskMetric(
-    override val name: String,
-    override val value: Long) extends CustomTaskMetric
-
 /**
  * A user-defined Python data source. This is used by the Python API.
  * Defines the interation between Python and JVM.
@@ -444,7 +191,7 @@ case class PythonLookupAllDataSourcesResult(
 /**
  * A runner used to look up Python Data Sources available in Python path.
  */
-class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
+private class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
    extends PythonPlannerRunner[PythonLookupAllDataSourcesResult](lookupSources) {
 
   override val workerModule = "pyspark.sql.worker.lookup_data_sources"
@@ -490,7 +237,7 @@ case class PythonDataSourceCreationResult(
 /**
  * A runner used to create a Python data source in a Python process and return the result.
  */
-class UserDefinedPythonDataSourceRunner(
+private class UserDefinedPythonDataSourceRunner(
     dataSourceCls: PythonFunction,
     provider: String,
     userSpecifiedSchema: Option[StructType],
@@ -560,7 +307,7 @@ case class PythonDataSourceReadInfo(
  * @param inputSchema input schema to the data source read from its child plan
  * @param outputSchema output schema of the Python data source
  */
-class UserDefinedPythonDataSourceReadRunner(
+private class UserDefinedPythonDataSourceReadRunner(
     func: PythonFunction,
     inputSchema: StructType,
    outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
@@ -617,7 +364,7 @@ case class PythonDataSourceWriteInfo(func: Array[Byte], writer: Array[Byte])
  * A runner that creates a Python data source writer instance and returns a Python function
  * to be used to write data into the data source.
  */
-class UserDefinedPythonDataSourceWriteRunner(
+private class UserDefinedPythonDataSourceWriteRunner(
     dataSourceCls: PythonFunction,
     provider: String,
     inputSchema: StructType,
@@ -675,7 +422,7 @@ class UserDefinedPythonDataSourceWriteRunner(
  * A runner that takes a Python data source writer and a list of commit messages,
  * and invokes the `commit` or `abort` method of the writer in Python.
  */
-class UserDefinedPythonDataSourceCommitRunner(
+private class UserDefinedPythonDataSourceCommitRunner(
     dataSourceCls: PythonFunction,
     writer: Array[Byte],
     messages: Array[WriterCommitMessage],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 1e691f71c511..62f3d7830ab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -30,7 +30,8 @@ import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunct
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.execution.python.{UserDefinedPythonDataSource, UserDefinedPythonFunction, UserDefinedPythonTableFunction}
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction}
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType}
 import org.apache.spark.util.ArrayImplicits._

