This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f3ffd3ca68d2  [SPARK-46720][SQL][PYTHON] Refactor Python Data Source to align with other DSv2 built-in Data Sources
f3ffd3ca68d2 is described below

commit f3ffd3ca68d27407a81e0406dbffe92b03a2d098
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Mon Jan 15 22:24:03 2024 +0900

    [SPARK-46720][SQL][PYTHON] Refactor Python Data Source to align with other DSv2 built-in Data Sources

    ### What changes were proposed in this pull request?

    This PR refactors Python Data Source to align with other DSv2 built-in Data Sources such as CSV, Parquet, ORC, JDBC, etc.

    ### Why are the changes needed?

    For better readability, maintenance, and consistency.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing test cases should cover them.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44734 from HyukjinKwon/SPARK-46720.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/datasource.py                   |   5 +-
 .../apache/spark/sql/DataSourceRegistration.scala  |   2 +-
 .../sql/execution/datasources/DataSource.scala     |   6 +-
 .../execution/datasources/DataSourceManager.scala  |   2 +-
 .../v2/python/PythonBatchWriterFactory.scala       |  80 ++++++
 .../datasources/v2/python/PythonCustomMetric.scala |  51 ++++
 .../datasources/v2/python/PythonDataSourceV2.scala |  67 +++++
 .../v2/python/PythonPartitionReaderFactory.scala   |  66 +++++
 .../datasources/v2/python/PythonScan.scala         |  57 +++++
 .../datasources/v2/python/PythonScanBuilder.scala  |  30 +++
 .../datasources/v2/python/PythonTable.scala        |  46 ++++
 .../datasources/v2/python/PythonWrite.scala        |  65 +++++
 .../datasources/v2/python/PythonWriteBuilder.scala |  36 +++
 .../v2}/python/UserDefinedPythonDataSource.scala   | 275 +--------------------
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |   3 +-
 15 files changed, 520 insertions(+), 271 deletions(-)
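For readers following along, the per-class layout this commit adopts mirrors the
standard DSv2 interface chain: a TableProvider resolves to a Table, the Table hands
out a ScanBuilder, the builder produces a Scan/Batch, and the Batch plans input
partitions and supplies a PartitionReaderFactory. The skeleton below is an
illustrative sketch of that contract, with hypothetical Example* names; it is not
part of this commit.

import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical minimal DSv2 source: one class per concern, mirroring the
// PythonDataSourceV2 / PythonTable / PythonScanBuilder / PythonScan split below.
class ExampleDataSourceV2 extends TableProvider {
  override def inferSchema(options: CaseInsensitiveStringMap): StructType =
    new StructType().add("value", IntegerType)

  override def getTable(
      schema: StructType,
      partitioning: Array[Transform],
      properties: util.Map[String, String]): Table = new ExampleTable(schema)
}

class ExampleTable(outputSchema: StructType) extends Table with SupportsRead {
  override def name(): String = "example"
  override def schema(): StructType = outputSchema
  override def capabilities(): util.Set[TableCapability] =
    util.EnumSet.of(TableCapability.BATCH_READ)
  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
    new ExampleScanBuilder(outputSchema)
}

class ExampleScanBuilder(outputSchema: StructType) extends ScanBuilder {
  override def build(): Scan = new ExampleScan(outputSchema)
}

// Like PythonScan, a single class can serve as both the Scan and its Batch.
class ExampleScan(outputSchema: StructType) extends Scan with Batch {
  override def readSchema(): StructType = outputSchema
  override def toBatch: Batch = this
  override def planInputPartitions(): Array[InputPartition] =
    Array(ExampleInputPartition(0), ExampleInputPartition(1))
  override def createReaderFactory(): PartitionReaderFactory = new ExampleReaderFactory
}

case class ExampleInputPartition(index: Int) extends InputPartition

class ExampleReaderFactory extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
    // Each task reads exactly one partition; here it emits a single row
    // carrying the partition index.
    val rows = Iterator.single(InternalRow(partition.asInstanceOf[ExampleInputPartition].index))
    new PartitionReader[InternalRow] {
      override def next(): Boolean = rows.hasNext
      override def get(): InternalRow = rows.next()
      override def close(): Unit = {}
    }
  }
}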
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index bdedbac3544e..1e50f8270243 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -407,7 +407,10 @@ class DataSourceRegistration:
         # Serialize the data source class.
         wrapped = _wrap_function(sc, dataSource)
         assert sc._jvm is not None
-        ds = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonDataSource(wrapped)
+        jvm = sc._jvm
+        ds = jvm.org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource(
+            wrapped
+        )
         self.sparkSession._jsparkSession.dataSource().registerPython(name, ds)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
index 2434103f4b80..63cee8861c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Evolving
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager}
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
 import org.apache.spark.sql.internal.SQLConf
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5190075f652b..3e03fd652f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -43,8 +43,8 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
 import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
-import org.apache.spark.sql.execution.python.PythonTableProvider
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
 import org.apache.spark.sql.internal.SQLConf
@@ -661,7 +661,7 @@ object DataSource extends Logging {
       } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
         throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
       } else if (isUserDefinedDataSource) {
-        classOf[PythonTableProvider]
+        classOf[PythonDataSourceV2]
       } else {
         throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
       }
@@ -734,7 +734,7 @@ object DataSource extends Logging {
       case t: TableProvider
           if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
         t match {
-          case p: PythonTableProvider => p.setShortName(provider)
+          case p: PythonDataSourceV2 => p.setShortName(provider)
           case _ =>
         }
         Some(t)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 53003989a338..28c93357d8b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.ConcurrentHashMap
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
 import org.apache.spark.util.Utils
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
new file mode 100644
index 000000000000..d5412f1bdd38
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import scala.jdk.CollectionConverters.IteratorHasAsScala
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+
+case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage
+
+case class PythonBatchWriterFactory(
+    source: UserDefinedPythonDataSource,
+    pickledWriteFunc: Array[Byte],
+    inputSchema: StructType,
+    jobArtifactUUID: Option[String]) extends DataWriterFactory {
+  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+    new DataWriter[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private var commitMessage: PythonWriterCommitMessage = _
+
+      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledWriteFunc,
+          "write_to_data_source",
+          inputSchema,
+          UserDefinedPythonDataSource.writeOutputSchema,
+          metrics,
+          jobArtifactUUID)
+        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
+        outputIter.foreach { row =>
+          if (commitMessage == null) {
+            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
+          } else {
+            throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
+          }
+        }
+        if (commitMessage == null) {
+          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
+        }
+      }
+
+      override def write(record: InternalRow): Unit =
+        SparkException.internalError("write method for Python data source should not be called.")
+
+      override def commit(): WriterCommitMessage = {
+        commitMessage.asInstanceOf[WriterCommitMessage]
+      }
+
+      override def abort(): Unit = {}
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
new file mode 100644
index 000000000000..bca1cbed7e70
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonCustomMetric.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.python.PythonSQLMetrics
+
+
+class PythonCustomMetric(
+    override val name: String,
+    override val description: String) extends CustomMetric {
+  // To allow the aggregation to be called. See `SQLAppStatusListener.aggregateMetrics`.
+  def this() = this(null, null)
+
+  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
+    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
+  }
+}
+
+class PythonCustomTaskMetric(
+    override val name: String,
+    override val value: Long) extends CustomTaskMetric
+
+
+object PythonCustomMetric {
+  val pythonMetrics: Map[String, SQLMetric] = {
+    // Dummy SQLMetrics. The result is manually reported via the DSv2 interface
+    // by passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
+    // is not used when it is reported. It is to reuse the existing Python runner.
+    // See also `UserDefinedPythonDataSource.createPythonMetrics`.
+    PythonSQLMetrics.pythonSizeMetricsDesc.keys
+      .map(_ -> new SQLMetric("size", -1)).toMap ++
+      PythonSQLMetrics.pythonOtherMetricsDesc.keys
+        .map(_ -> new SQLMetric("sum", -1)).toMap
+  }
+}
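The split above follows the standard DSv2 custom-metrics contract: each task reports
CustomTaskMetric values via currentMetricsValues(), and the driver aggregates them
through a CustomMetric that must be zero-arg constructible so the status listener can
instantiate it. A minimal sketch of that contract, with a hypothetical row-count
metric (not part of this commit):

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}

// Hypothetical driver-side metric: the zero-arg constructor lets
// SQLAppStatusListener instantiate it when aggregating task values.
class ExampleRowCountMetric extends CustomMetric {
  override def name(): String = "exampleRows"
  override def description(): String = "number of example rows"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    taskMetrics.sum.toString
}

// Hypothetical task-side value, as returned from
// PartitionReader.currentMetricsValues() / DataWriter.currentMetricsValues().
class ExampleRowCountTaskMetric(override val value: Long) extends CustomTaskMetric {
  override def name(): String = "exampleRows"
}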
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
new file mode 100644
index 000000000000..edea70258779
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonDataSourceV2.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+/**
+ * Data Source V2 wrapper for Python Data Source.
+ */
+class PythonDataSourceV2 extends TableProvider {
+  private var name: String = _
+
+  def setShortName(str: String): Unit = {
+    assert(name == null)
+    name = str
+  }
+
+  private def shortName: String = {
+    assert(name != null)
+    name
+  }
+
+  private var dataSourceInPython: PythonDataSourceCreationResult = _
+  private[python] lazy val source: UserDefinedPythonDataSource =
+    SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
+
+  def getOrCreateDataSourceInPython(
+      shortName: String,
+      options: CaseInsensitiveStringMap,
+      userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
+    if (dataSourceInPython == null) {
+      dataSourceInPython = source.createDataSourceInPython(shortName, options, userSpecifiedSchema)
+    }
+    dataSourceInPython
+  }
+
+  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+    getOrCreateDataSourceInPython(shortName, options, None).schema
+  }
+
+  override def getTable(
+      schema: StructType,
+      partitioning: Array[Transform],
+      properties: java.util.Map[String, String]): Table = {
+    new PythonTable(this, shortName, schema)
+  }
+
+  override def supportsExternalMetadata(): Boolean = true
+}
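Once a Python data source has been registered (from Python, via
spark.dataSource.register(...)), lookup flows through the usual DataSource
resolution path and lands on PythonDataSourceV2, so reads look the same as for any
built-in source. A hypothetical usage sketch, assuming a source registered under the
short name "my_source":

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Hypothetical short name: "my_source" must already be registered from Python.
val df = spark.read
  .format("my_source")         // DataSource.lookupDataSource resolves this to PythonDataSourceV2
  .option("path", "/tmp/demo") // options reach the Python DataSource instance
  .load()                      // triggers inferSchema() -> getOrCreateDataSourceInPython()
df.show()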
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
new file mode 100644
index 000000000000..44933779c26a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+
+
+case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
+
+class PythonPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory {
+
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    new PartitionReader[InternalRow] {
+
+      private[this] val metrics: Map[String, SQLMetric] = PythonCustomMetric.pythonMetrics
+
+      private val outputIter = {
+        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+          pickledReadFunc,
+          "read_from_data_source",
+          UserDefinedPythonDataSource.readInputSchema,
+          outputSchema,
+          metrics,
+          jobArtifactUUID)
+
+        val part = partition.asInstanceOf[PythonInputPartition]
+        evaluatorFactory.createEvaluator().eval(
+          part.index, Iterator.single(InternalRow(part.pickedPartition)))
+      }
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+
+      override def currentMetricsValues(): Array[CustomTaskMetric] = {
+        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
new file mode 100644
index 000000000000..75cbe38b1397
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.read._
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonScan(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap) extends Batch with Scan {
+
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  private lazy val infoInPython: PythonDataSourceReadInfo = {
+    ds.source.createReadInfoInPython(
+      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
+      outputSchema)
+  }
+
+  override def planInputPartitions(): Array[InputPartition] =
+    infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    val readerFunc = infoInPython.func
+    new PythonPartitionReaderFactory(
+      ds.source, readerFunc, outputSchema, jobArtifactUUID)
+  }
+
+  override def toBatch: Batch = this
+
+  override def description: String = "(Python)"
+
+  override def readSchema(): StructType = outputSchema
+
+  override def supportedCustomMetrics(): Array[CustomMetric] =
+    ds.source.createPythonMetrics()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
new file mode 100644
index 000000000000..e30fc9f7978c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonScanBuilder(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap) extends ScanBuilder {
+  override def build(): Scan = new PythonScan(ds, shortName, outputSchema, options)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
new file mode 100644
index 000000000000..6bea97795a35
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
+import org.apache.spark.sql.connector.read.ScanBuilder
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+
+class PythonTable(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType
+  ) extends Table with SupportsRead with SupportsWrite {
+  override def name(): String = shortName
+
+  override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
+    BATCH_READ, BATCH_WRITE, TRUNCATE)
+
+  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+    new PythonScanBuilder(ds, shortName, outputSchema, options)
+  }
+
+  override def schema(): StructType = outputSchema
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    new PythonWriteBuilder(ds, shortName, info)
+  }
+}
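PythonTable advertises BATCH_READ, BATCH_WRITE, and TRUNCATE, which is what lets the
DataFrame writer modes map onto the write path that follows: append needs only
BATCH_WRITE, while overwrite is expected to be planned as a truncate followed by a
write, and so additionally needs TRUNCATE (implemented by PythonWriteBuilder via
SupportsTruncate below). A hypothetical usage sketch, again assuming a source
registered under "my_source" whose Python class implements a writer:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
val data = spark.range(10)

// Hypothetical short name; requires a registered Python data source with a writer.
data.write.format("my_source").mode("append").save()    // needs BATCH_WRITE
data.write.format("my_source").mode("overwrite").save() // additionally needs TRUNCATE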
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
new file mode 100644
index 000000000000..d216dfde9974
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.connector.metric.CustomMetric
+import org.apache.spark.sql.connector.write._
+
+
+class PythonWrite(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    info: LogicalWriteInfo,
+    isTruncate: Boolean
+  ) extends Write with BatchWrite {
+
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  // Store the pickled data source writer instance.
+  private var pythonDataSourceWriter: Array[Byte] = _
+
+  override def createBatchWriterFactory(
+      physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
+
+    val writeInfo = ds.source.createWriteInfoInPython(
+      shortName,
+      info.schema(),
+      info.options(),
+      isTruncate)
+
+    pythonDataSourceWriter = writeInfo.writer
+
+    PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID)
+  }
+
+  override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    ds.source.commitWriteInPython(pythonDataSourceWriter, messages)
+  }
+
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {
+    ds.source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true)
+  }
+
+  override def toString: String = shortName
+
+  override def toBatch: BatchWrite = this
+
+  override def description: String = "(Python)"
+
+  override def supportedCustomMetrics(): Array[CustomMetric] =
+    ds.source.createPythonMetrics()
+}
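PythonWrite follows the usual DSv2 two-phase commit: each task's DataWriter produces
a WriterCommitMessage (here, the pickled message coming back from Python), and the
driver's BatchWrite.commit or abort receives the full set. A minimal sketch of that
task/driver handshake with hypothetical names, not part of this commit:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write._

// Hypothetical message and writer illustrating the contract that
// PythonBatchWriterFactory and PythonWrite implement above.
case class ExampleCommitMessage(rowsWritten: Long) extends WriterCommitMessage

class ExampleDataWriter extends DataWriter[InternalRow] {
  private var rows = 0L
  override def write(record: InternalRow): Unit = { rows += 1 } // buffer/emit per row
  override def commit(): WriterCommitMessage = ExampleCommitMessage(rows) // task succeeded
  override def abort(): Unit = {}  // task failed: discard partial output
  override def close(): Unit = {}  // always called last
}

// Driver side: BatchWrite.commit(messages) sees one message per successful
// task; in PythonWrite this is where commitWriteInPython is invoked.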
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala
new file mode 100644
index 000000000000..dffc184ffb87
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWriteBuilder.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.sql.connector.write._
+
+
+class PythonWriteBuilder(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    info: LogicalWriteInfo
+  ) extends WriteBuilder with SupportsTruncate {
+
+  private var isTruncate = false
+
+  override def truncate(): WriteBuilder = {
+    isTruncate = true
+    this
+  }
+
+  override def build(): Write = new PythonWrite(ds, shortName, info, isTruncate)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
similarity index 61%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 4b567c591672..f11a63429d78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.python
+package org.apache.spark.sql.execution.datasources.v2.python
 
 import java.io.{DataInputStream, DataOutputStream}
 
@@ -24,273 +24,20 @@ import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.{JobArtifactSet, SparkException}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonUtils, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, TRUNCATE}
-import org.apache.spark.sql.connector.expressions.Transform
 import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
-import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, Write, WriteBuilder, WriterCommitMessage}
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, MapInBatchEvaluatorFactory, PythonPlannerRunner, PythonSQLMetrics}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.ArrayImplicits._
 
-/**
- * Data Source V2 wrapper for Python Data Source.
- */
-class PythonTableProvider extends TableProvider {
-  private var name: String = _
-  def setShortName(str: String): Unit = {
-    assert(name == null)
-    name = str
-  }
-  private def shortName: String = {
-    assert(name != null)
-    name
-  }
-  private var dataSourceInPython: PythonDataSourceCreationResult = _
-  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-  private lazy val source: UserDefinedPythonDataSource =
-    SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
-  override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
-    if (dataSourceInPython == null) {
-      dataSourceInPython = source.createDataSourceInPython(shortName, options, None)
-    }
-    dataSourceInPython.schema
-  }
-
-  override def getTable(
-      schema: StructType,
-      partitioning: Array[Transform],
-      properties: java.util.Map[String, String]): Table = {
-    val outputSchema = schema
-    new Table with SupportsRead with SupportsWrite {
-      override def name(): String = shortName
-
-      override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
-        BATCH_READ, BATCH_WRITE, TRUNCATE)
-
-      override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
-        new ScanBuilder with Batch with Scan {
-
-          private lazy val infoInPython: PythonDataSourceReadInfo = {
-            if (dataSourceInPython == null) {
-              dataSourceInPython = source
-                .createDataSourceInPython(shortName, options, Some(outputSchema))
-            }
-            source.createReadInfoInPython(dataSourceInPython, outputSchema)
-          }
-
-          override def build(): Scan = this
-
-          override def toBatch: Batch = this
-
-          override def readSchema(): StructType = outputSchema
-
-          override def planInputPartitions(): Array[InputPartition] =
-            infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray
-
-          override def createReaderFactory(): PartitionReaderFactory = {
-            val readerFunc = infoInPython.func
-            new PythonPartitionReaderFactory(
-              source, readerFunc, outputSchema, jobArtifactUUID)
-          }
-
-          override def description: String = "(Python)"
-
-          override def supportedCustomMetrics(): Array[CustomMetric] =
-            source.createPythonMetrics()
-        }
-      }
-
-      override def schema(): StructType = outputSchema
-
-      override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
-        new WriteBuilder with SupportsTruncate {
-
-          private var isTruncate = false
-
-          override def truncate(): WriteBuilder = {
-            isTruncate = true
-            this
-          }
-
-          override def build(): Write = new Write {
-
-            override def toBatch: BatchWrite = new BatchWrite {
-
-              // Store the pickled data source writer instance.
-              private var pythonDataSourceWriter: Array[Byte] = _
-
-              override def createBatchWriterFactory(
-                  physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
-
-                val writeInfo = source.createWriteInfoInPython(
-                  shortName,
-                  info.schema(),
-                  info.options(),
-                  isTruncate)
-
-                pythonDataSourceWriter = writeInfo.writer
-
-                PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID)
-              }
-
-              override def commit(messages: Array[WriterCommitMessage]): Unit = {
-                source.commitWriteInPython(pythonDataSourceWriter, messages)
-              }
-
-              override def abort(messages: Array[WriterCommitMessage]): Unit = {
-                source.commitWriteInPython(pythonDataSourceWriter, messages, abort = true)
-              }
-
-              override def toString: String = shortName
-            }
-
-            override def description: String = "(Python)"
-
-            override def supportedCustomMetrics(): Array[CustomMetric] =
-              source.createPythonMetrics()
-          }
-        }
-      }
-    }
-  }
-
-  override def supportsExternalMetadata(): Boolean = true
-}
-
-case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition
-
-class PythonPartitionReaderFactory(
-    source: UserDefinedPythonDataSource,
-    pickledReadFunc: Array[Byte],
-    outputSchema: StructType,
-    jobArtifactUUID: Option[String])
-  extends PartitionReaderFactory with PythonDataSourceSQLMetrics {
-
-  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    new PartitionReader[InternalRow] {
-
-      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
-
-      private val outputIter = {
-        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-          pickledReadFunc,
-          "read_from_data_source",
-          UserDefinedPythonDataSource.readInputSchema,
-          outputSchema,
-          metrics,
-          jobArtifactUUID)
-
-        val part = partition.asInstanceOf[PythonInputPartition]
-        evaluatorFactory.createEvaluator().eval(
-          part.index, Iterator.single(InternalRow(part.pickedPartition)))
-      }
-
-      override def next(): Boolean = outputIter.hasNext
-
-      override def get(): InternalRow = outputIter.next()
-
-      override def close(): Unit = {}
-
-      override def currentMetricsValues(): Array[CustomTaskMetric] = {
-        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value})
-      }
-    }
-  }
-}
-
-case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage
-
-private case class PythonBatchWriterFactory(
-    source: UserDefinedPythonDataSource,
-    pickledWriteFunc: Array[Byte],
-    inputSchema: StructType,
-    jobArtifactUUID: Option[String]) extends DataWriterFactory with PythonDataSourceSQLMetrics {
-  override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
-    new DataWriter[InternalRow] {
-
-      private[this] val metrics: Map[String, SQLMetric] = pythonMetrics
-
-      private var commitMessage: PythonWriterCommitMessage = _
-
-      override def writeAll(records: java.util.Iterator[InternalRow]): Unit = {
-        val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
-          pickledWriteFunc,
-          "write_to_data_source",
-          inputSchema,
-          UserDefinedPythonDataSource.writeOutputSchema,
-          metrics,
-          jobArtifactUUID)
-        val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
-        outputIter.foreach { row =>
-          if (commitMessage == null) {
-            commitMessage = PythonWriterCommitMessage(row.getBinary(0))
-          } else {
-            throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one")
-          }
-        }
-        if (commitMessage == null) {
-          throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero")
-        }
-      }
-
-      override def write(record: InternalRow): Unit =
-        SparkException.internalError("write method for Python data source should not be called.")
-
-      override def commit(): WriterCommitMessage = {
-        commitMessage.asInstanceOf[WriterCommitMessage]
-      }
-
-      override def abort(): Unit = {}
-
-      override def close(): Unit = {}
-
-      override def currentMetricsValues(): Array[CustomTaskMetric] = {
-        source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value })
-      }
-    }
-  }
-}
-
-trait PythonDataSourceSQLMetrics {
-  // Dummy SQLMetrics. The result is manually reported via DSv2 interface
-  // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc`
-  // is not used when it is reported. It is to reuse existing Python runner.
-  // See also `UserDefinedPythonDataSource.createPythonMetrics`.
-  protected lazy val pythonMetrics: Map[String, SQLMetric] = {
-    PythonSQLMetrics.pythonSizeMetricsDesc.keys
-      .map(_ -> new SQLMetric("size", -1)).toMap ++
-      PythonSQLMetrics.pythonOtherMetricsDesc.keys
-        .map(_ -> new SQLMetric("sum", -1)).toMap
-  }
-}
-
-class PythonCustomMetric(
-    override val name: String,
-    override val description: String) extends CustomMetric {
-  // To allow the aggregation can be called. See `SQLAppStatusListener.aggregateMetrics`
-  def this() = this(null, null)
-
-  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
-    SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
-  }
-}
-
-class PythonCustomTaskMetric(
-    override val name: String,
-    override val value: Long) extends CustomTaskMetric
-
 /**
  * A user-defined Python data source. This is used by the Python API.
  * Defines the interaction between Python and JVM.
@@ -444,7 +191,7 @@ case class PythonLookupAllDataSourcesResult(
 /**
  * A runner used to look up Python Data Sources available in Python path.
  */
-class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
+private class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
   extends PythonPlannerRunner[PythonLookupAllDataSourcesResult](lookupSources) {
 
   override val workerModule = "pyspark.sql.worker.lookup_data_sources"
@@ -490,7 +237,7 @@ case class PythonDataSourceCreationResult(
 /**
  * A runner used to create a Python data source in a Python process and return the result.
  */
-class UserDefinedPythonDataSourceRunner(
+private class UserDefinedPythonDataSourceRunner(
     dataSourceCls: PythonFunction,
     provider: String,
     userSpecifiedSchema: Option[StructType],
@@ -560,7 +307,7 @@ case class PythonDataSourceReadInfo(
  * @param inputSchema input schema to the data source read from its child plan
  * @param outputSchema output schema of the Python data source
  */
-class UserDefinedPythonDataSourceReadRunner(
+private class UserDefinedPythonDataSourceReadRunner(
     func: PythonFunction,
     inputSchema: StructType,
     outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
@@ -617,7 +364,7 @@ case class PythonDataSourceWriteInfo(func: Array[Byte], writer: Array[Byte])
  * A runner that creates a Python data source writer instance and returns a Python function
  * to be used to write data into the data source.
  */
-class UserDefinedPythonDataSourceWriteRunner(
+private class UserDefinedPythonDataSourceWriteRunner(
     dataSourceCls: PythonFunction,
     provider: String,
     inputSchema: StructType,
@@ -675,7 +422,7 @@ class UserDefinedPythonDataSourceWriteRunner(
  * A runner that takes a Python data source writer and a list of commit messages,
  * and invokes the `commit` or `abort` method of the writer in Python.
  */
-class UserDefinedPythonDataSourceCommitRunner(
+private class UserDefinedPythonDataSourceCommitRunner(
     dataSourceCls: PythonFunction,
     writer: Array[Byte],
     messages: Array[WriterCommitMessage],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index 1e691f71c511..62f3d7830ab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -30,7 +30,8 @@ import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunct
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.execution.python.{UserDefinedPythonDataSource, UserDefinedPythonFunction, UserDefinedPythonTableFunction}
+import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource
+import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction}
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType}
 import org.apache.spark.util.ArrayImplicits._

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org