This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 62fc27d79d5c [SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2
62fc27d79d5c is described below

commit 62fc27d79d5ccce476671a7a664272c718024617
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Dec 15 15:18:20 2023 -0800

    [SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2
    
    ### What changes were proposed in this pull request?
    
    This PR is a followup of https://github.com/apache/spark/pull/44305 that proposes to create the Python Data Source instance at `DataSource.lookupDataSourceV2`.
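    
    For illustration, a minimal sketch of the call-site pattern this moves to, using stand-in types rather than Spark's actual classes: each call site instantiates the looked-up class through its no-argument constructor, with no Python-specific branch.
    
    ```scala
    // Simplified sketch: stand-in types, not Spark's actual classes.
    trait TableProvider
    class ExamplePythonTableProvider extends TableProvider

    object InstantiationSketch {
      def main(args: Array[String]): Unit = {
        // What DataSource.lookupDataSource conceptually returns: a Class[_].
        val cls: Class[_] = classOf[ExamplePythonTableProvider]
        // The call-site pattern after this change: plain reflective
        // no-arg construction, with no special case for Python sources.
        val instance = cls.getDeclaredConstructor().newInstance()
        assert(instance.isInstanceOf[TableProvider])
      }
    }
    ```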
    
    ### Why are the changes needed?
    
    Semantically, the instance has to be ready at the `DataSource.lookupDataSourceV2` level instead of after that. It is also more consistent.
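    
    As a simplified, self-contained sketch of that flow (stand-in types only; the real logic lives in `DataSource.lookupDataSourceV2` and `PythonTableProvider`, see the diff below): the lookup constructs the instance with the no-arg constructor and, for a Python data source, injects the provider short name right there.
    
    ```scala
    // Simplified, standalone model; stand-in types, not Spark's real classes.
    trait TableProvider

    class PythonTableProvider extends TableProvider {
      private var name: String = _
      // Injected by the lookup right after construction.
      def setShortName(str: String): Unit = { assert(name == null); name = str }
      def shortName: String = { assert(name != null); name }
    }

    object LookupSketch {
      // Mirrors the shape of the new lookupDataSourceV2: construct via the
      // no-arg constructor, then set the short name for Python data sources.
      def lookupDataSourceV2(provider: String, cls: Class[_]): Option[TableProvider] =
        cls.getDeclaredConstructor().newInstance() match {
          case t: TableProvider =>
            t match {
              case p: PythonTableProvider => p.setShortName(provider)
              case _ =>
            }
            Some(t)
          case _ => None
        }

      def main(args: Array[String]): Unit = {
        val t = lookupDataSourceV2("my_python_source", classOf[PythonTableProvider])
        // Prints: Some(my_python_source)
        println(t.collect { case p: PythonTableProvider => p.shortName })
      }
    }
    ```
    
    The actual change in the diff keeps `PythonTableProvider` with a no-arg constructor and has `lookupDataSourceV2` call `setShortName(provider)` on it after instantiation.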
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests should cover this change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #44374 from HyukjinKwon/SPARK-46423.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../apache/spark/sql/execution/command/ddl.scala   |  6 ++----
 .../spark/sql/execution/command/tables.scala       |  7 +++---
 .../sql/execution/datasources/DataSource.scala     | 25 ++++++++--------------
 .../python/UserDefinedPythonDataSource.scala       | 11 +++++++++-
 .../spark/sql/streaming/DataStreamReader.scala     |  5 ++---
 .../spark/sql/streaming/DataStreamWriter.scala     |  2 +-
 7 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 44a4d82c1dac..15eeca87dcf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
     DataSource.lookupDataSource(runner, sessionState.conf) match {
       case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
         Dataset.ofRows(self, ExternalCommandExecutor(
-          DataSource.newDataSourceInstance(runner, source)
+          source.getDeclaredConstructor().newInstance()
             .asInstanceOf[ExternalCommandRunner], command, options))
 
       case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 199c8728a5c9..dc1c5b3fd580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
@@ -1025,9 +1025,7 @@ object DDLUtils extends Logging {
 
   def checkDataColNames(provider: String, schema: StructType): Unit = {
     val source = try {
-      DataSource.newDataSourceInstance(
-        provider,
-        DataSource.lookupDataSource(provider, SQLConf.get))
+      DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
     } catch {
       case e: Throwable =>
         logError(s"Failed to find data source: $provider when check data 
column names.", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 9771ee08b258..2f8fca7cfd73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,9 +264,8 @@ case class AlterTableAddColumnsCommand(
     }
 
     if (DDLUtils.isDatasourceTable(catalogTable)) {
-      DataSource.newDataSourceInstance(
-          catalogTable.provider.get,
-          DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
+      DataSource.lookupDataSource(catalogTable.provider.get, conf).
+        getConstructor().newInstance() match {
         // For datasource table, this command can only support the following File format.
         // TextFileFormat only default to one column "value"
         // Hive type is already considered as hive serde table, so the logic will not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 9612d8ff24f5..efec44658d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -106,14 +106,13 @@ case class DataSource(
     // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
     // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
     // instead of `providingClass`.
-    DataSource.newDataSourceInstance(className, cls) match {
+    cls.getDeclaredConstructor().newInstance() match {
       case f: FileDataSourceV2 => f.fallbackFileFormat
       case _ => cls
     }
   }
 
-  private[sql] def providingInstance(): Any =
-    DataSource.newDataSourceInstance(className, providingClass)
+  private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
 
   private def newHadoopConfiguration(): Configuration =
     sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -624,15 +623,6 @@ object DataSource extends Logging {
     "org.apache.spark.sql.sources.HadoopFsRelationProvider",
     "org.apache.spark.Logging")
 
-  /** Create the instance of the datasource */
-  def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
-    providingClass match {
-      case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
-        cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
-      case cls => cls.getDeclaredConstructor().newInstance()
-    }
-  }
-
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
     val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -732,9 +722,9 @@ object DataSource extends Logging {
   def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
     val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
       .split(",").map(_.trim)
-    val providingClass = lookupDataSource(provider, conf)
+    val cls = lookupDataSource(provider, conf)
     val instance = try {
-      newDataSourceInstance(provider, providingClass)
+      cls.getDeclaredConstructor().newInstance()
     } catch {
       // Throw the original error from the data source implementation.
       case e: java.lang.reflect.InvocationTargetException => throw e.getCause
@@ -742,8 +732,11 @@ object DataSource extends Logging {
     instance match {
       case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
       case t: TableProvider
-          if !useV1Sources.contains(
-            providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+          if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+        t match {
+          case p: PythonTableProvider => p.setShortName(provider)
+          case _ =>
+        }
         Some(t)
       case _ => None
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7c850d1e2890..5e978a900884 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -44,7 +44,16 @@ import org.apache.spark.util.ArrayImplicits._
 /**
  * Data Source V2 wrapper for Python Data Source.
  */
-class PythonTableProvider(shortName: String) extends TableProvider {
+class PythonTableProvider extends TableProvider {
+  private var name: String = _
+  def setShortName(str: String): Unit = {
+    assert(name == null)
+    name = str
+  }
+  private def shortName: String = {
+    assert(name != null)
+    name
+  }
   private var dataSourceInPython: PythonDataSourceCreationResult = _
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
   private lazy val source: UserDefinedPythonDataSource =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index c93ca632d3c7..1a69678c2f54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,9 +156,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
       extraOptions + ("path" -> path.get)
     }
 
-    val ds = DataSource.newDataSourceInstance(
-      source,
-      DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
+    val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf).
+      getConstructor().newInstance()
     // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
     // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
     // writer or whether the query is continuous.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 7202f69ab1bf..95aa2f8c7a4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
       }
 
       val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
-        val provider = DataSource.newDataSourceInstance(source, cls).asInstanceOf[TableProvider]
+        val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
         val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
           source = provider, conf = df.sparkSession.sessionState.conf)
         val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++

