This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 62fc27d79d5c [SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2
62fc27d79d5c is described below

commit 62fc27d79d5ccce476671a7a664272c718024617
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Dec 15 15:18:20 2023 -0800

    [SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2

    ### What changes were proposed in this pull request?

    This PR is a followup of https://github.com/apache/spark/pull/44305. It proposes to create the Python Data Source instance at `DataSource.lookupDataSourceV2`.

    ### Why are the changes needed?

    Semantically, the instance has to be ready at the `DataSource.lookupDataSourceV2` level rather than after it. It is also more consistent.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing tests should cover this.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #44374 from HyukjinKwon/SPARK-46423.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../scala/org/apache/spark/sql/SparkSession.scala |  2 +-
 .../apache/spark/sql/execution/command/ddl.scala  |  6 ++----
 .../spark/sql/execution/command/tables.scala      |  7 +++---
 .../sql/execution/datasources/DataSource.scala    | 25 ++++++++--------------
 .../python/UserDefinedPythonDataSource.scala      | 11 +++++++++-
 .../spark/sql/streaming/DataStreamReader.scala    |  5 ++---
 .../spark/sql/streaming/DataStreamWriter.scala    |  2 +-
 7 files changed, 28 insertions(+), 30 deletions(-)
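For readers who want the shape of the change before the full diff: the sketch below condenses the new `lookupDataSourceV2` flow from the DataSource.scala hunk further down. It is a simplified illustration that assumes a Spark classpath, and the wrapper object and method name (`LookupSketch`, `lookupV2Sketch`) are invented for the example; it is not the exact code that was merged.

```scala
import java.util.Locale

import org.apache.spark.sql.connector.catalog.TableProvider
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.DataSourceRegister

object LookupSketch {
  // Condensed view of the new DataSource.lookupDataSourceV2 path: the provider
  // class is instantiated through its no-arg constructor, and the Python provider
  // is told its short name immediately after instantiation.
  def lookupV2Sketch(provider: String, conf: SQLConf): Option[TableProvider] = {
    val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST)
      .toLowerCase(Locale.ROOT).split(",").map(_.trim)
    val cls = DataSource.lookupDataSource(provider, conf)
    val instance = cls.getDeclaredConstructor().newInstance()
    instance match {
      case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
      case t: TableProvider
          if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
        // New in this commit: the short name is pushed into PythonTableProvider here.
        t match {
          case p: PythonTableProvider => p.setShortName(provider)
          case _ =>
        }
        Some(t)
      case _ => None
    }
  }
}
```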
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 44a4d82c1dac..15eeca87dcf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
     DataSource.lookupDataSource(runner, sessionState.conf) match {
       case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
         Dataset.ofRows(self, ExternalCommandExecutor(
-          DataSource.newDataSourceInstance(runner, source)
+          source.getDeclaredConstructor().newInstance()
            .asInstanceOf[ExternalCommandRunner], command, options))
       case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 199c8728a5c9..dc1c5b3fd580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 import org.apache.spark.sql.types._
@@ -1025,9 +1025,7 @@ object DDLUtils extends Logging {
   def checkDataColNames(provider: String, schema: StructType): Unit = {
     val source = try {
-      DataSource.newDataSourceInstance(
-        provider,
-        DataSource.lookupDataSource(provider, SQLConf.get))
+      DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
     } catch {
       case e: Throwable =>
         logError(s"Failed to find data source: $provider when check data column names.", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 9771ee08b258..2f8fca7cfd73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,9 +264,8 @@ case class AlterTableAddColumnsCommand(
     }
 
     if (DDLUtils.isDatasourceTable(catalogTable)) {
-      DataSource.newDataSourceInstance(
-        catalogTable.provider.get,
-        DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
+      DataSource.lookupDataSource(catalogTable.provider.get, conf).
+        getConstructor().newInstance() match {
        // For datasource table, this command can only support the following File format.
        // TextFileFormat only default to one column "value"
        // Hive type is already considered as hive serde table, so the logic will not
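With the String-taking constructor gone, the call sites above (and the streaming ones below) all reduce to the same two-step reflection pattern. A minimal sketch, assuming a Spark classpath and using the built-in `parquet` provider purely as an example (the `CallSiteSketch` object is invented for illustration):

```scala
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.internal.SQLConf

object CallSiteSketch {
  def main(args: Array[String]): Unit = {
    // Resolve the provider name to its implementing class, then instantiate it
    // through the public no-arg constructor, mirroring what ddl.scala and
    // tables.scala now do. No special case for Python sources is needed any more.
    val cls = DataSource.lookupDataSource("parquet", SQLConf.get)
    val instance = cls.getConstructor().newInstance()
    println(s"${cls.getName} -> ${instance.getClass.getSimpleName}")
  }
}
```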
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 9612d8ff24f5..efec44658d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -106,14 +106,13 @@ case class DataSource(
     // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
     // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
     // instead of `providingClass`.
-    DataSource.newDataSourceInstance(className, cls) match {
+    cls.getDeclaredConstructor().newInstance() match {
       case f: FileDataSourceV2 => f.fallbackFileFormat
       case _ => cls
     }
   }
 
-  private[sql] def providingInstance(): Any =
-    DataSource.newDataSourceInstance(className, providingClass)
+  private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
 
   private def newHadoopConfiguration(): Configuration =
     sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -624,15 +623,6 @@ object DataSource extends Logging {
     "org.apache.spark.sql.sources.HadoopFsRelationProvider",
     "org.apache.spark.Logging")
 
-  /** Create the instance of the datasource */
-  def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
-    providingClass match {
-      case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
-        cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
-      case cls => cls.getDeclaredConstructor().newInstance()
-    }
-  }
-
   /** Given a provider name, look up the data source class definition. */
   def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
     val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -732,9 +722,9 @@ object DataSource extends Logging {
   def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
     val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
       .split(",").map(_.trim)
-    val providingClass = lookupDataSource(provider, conf)
+    val cls = lookupDataSource(provider, conf)
     val instance = try {
-      newDataSourceInstance(provider, providingClass)
+      cls.getDeclaredConstructor().newInstance()
     } catch {
       // Throw the original error from the data source implementation.
       case e: java.lang.reflect.InvocationTargetException => throw e.getCause
     }
     instance match {
       case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
       case t: TableProvider
-          if !useV1Sources.contains(
-            providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+          if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+        t match {
+          case p: PythonTableProvider => p.setShortName(provider)
+          case _ =>
+        }
         Some(t)
       case _ => None
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7c850d1e2890..5e978a900884 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -44,7 +44,16 @@ import org.apache.spark.util.ArrayImplicits._
 /**
  * Data Source V2 wrapper for Python Data Source.
  */
-class PythonTableProvider(shortName: String) extends TableProvider {
+class PythonTableProvider extends TableProvider {
+  private var name: String = _
+  def setShortName(str: String): Unit = {
+    assert(name == null)
+    name = str
+  }
+  private def shortName: String = {
+    assert(name != null)
+    name
+  }
   private var dataSourceInPython: PythonDataSourceCreationResult = _
   private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
   private lazy val source: UserDefinedPythonDataSource =
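The accessor pair added to `PythonTableProvider` above is a small assert-guarded, set-once pattern. The standalone sketch below (plain Scala with a hypothetical class name, not Spark code) shows the intended call order and what each assertion guards against:

```scala
object SetOnceDemo {
  // Mirrors the shape of PythonTableProvider's name handling: constructed with a
  // no-arg constructor, given its short name exactly once, read only afterwards.
  class SetOnceName {
    private var name: String = _

    def setShortName(str: String): Unit = {
      assert(name == null)   // calling this twice would fail here
      name = str
    }

    def shortName: String = {
      assert(name != null)   // reading before setShortName would fail here
      name
    }
  }

  def main(args: Array[String]): Unit = {
    val p = new SetOnceName
    p.setShortName("my_source")   // done by lookupDataSourceV2 in the real code
    println(p.shortName)
  }
}
```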
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index c93ca632d3c7..1a69678c2f54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,9 +156,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
       extraOptions + ("path" -> path.get)
     }
 
-    val ds = DataSource.newDataSourceInstance(
-      source,
-      DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
+    val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf).
+      getConstructor().newInstance()
     // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
     // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
     // writer or whether the query is continuous.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 7202f69ab1bf..95aa2f8c7a4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
-      val provider = DataSource.newDataSourceInstance(source, cls).asInstanceOf[TableProvider]
+      val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         source = provider, conf = df.sparkSession.sessionState.conf)
       val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org