This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4b40920e331 [SPARK-41713][SQL] Make CTAS hold a nested execution for data writing
4b40920e331 is described below
commit 4b40920e33176fc8b18380703e4dcf4d16824094
Author: ulysses-you <[email protected]>
AuthorDate: Wed Dec 28 17:11:59 2022 +0800
[SPARK-41713][SQL] Make CTAS hold a nested execution for data writing
### What changes were proposed in this pull request?
This PR makes CTAS use a nested execution instead of running the data writing command directly.
This lets us clean up CTAS itself and remove the unnecessary V1 write information. Now there are
only two V1 write implementations: `InsertIntoHadoopFsRelationCommand` and `InsertIntoHiveTable`.
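At a high level, the CTAS commands now plan the inner write command as its own query execution instead of running it against a pre-built physical plan. A minimal sketch of that pattern, using the internal `SessionState`/`QueryExecution` APIs touched in the diff below (the `buildWriteCommand` helper is a stand-in, not a real method):
```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch only: `buildWriteCommand` stands in for the logic that produces an
// InsertIntoHadoopFsRelationCommand / InsertIntoHiveTable for the CTAS query.
// Note these APIs are private[sql], so this only compiles inside Spark itself.
def runCtasWrite(spark: SparkSession, buildWriteCommand: () => LogicalPlan): Seq[Row] = {
  val writeCommand = buildWriteCommand()
  // Plan and run the inner write as a nested execution; its plan and metrics
  // show up as a separate query execution instead of hanging off the CTAS node.
  val qe = spark.sessionState.executePlan(writeCommand)
  qe.assertCommandExecuted()
  Seq.empty[Row]
}
```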
### Why are the changes needed?
Make the V1 writes code clearer.
```sql
EXPLAIN FORMATTED CREATE TABLE t2 USING PARQUET AS SELECT * FROM t;
== Physical Plan ==
Execute CreateDataSourceTableAsSelectCommand (1)
   +- CreateDataSourceTableAsSelectCommand (2)
         +- Project (5)
            +- SubqueryAlias (4)
               +- LogicalRelation (3)
(1) Execute CreateDataSourceTableAsSelectCommand
Output: []
(2) CreateDataSourceTableAsSelectCommand
Arguments: `spark_catalog`.`default`.`t2`, ErrorIfExists, [c1, c2]
(3) LogicalRelation
Arguments: parquet, [c1#11, c2#12], `spark_catalog`.`default`.`t`, org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe, false
(4) SubqueryAlias
Arguments: spark_catalog.default.t
(5) Project
Arguments: [c1#11, c2#12]
```
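A side effect (relied on by the test updates below) is that each CTAS now triggers nested query executions: one for the CTAS command itself, one for the inner `InsertIntoHadoopFsRelationCommand`/`InsertIntoHiveTable` write, plus the final `CommandResult`. A rough way to observe this with a `QueryExecutionListener`, sketched here only as an illustration (table names are placeholders, and the listener bus is asynchronous, so the updated suites additionally wait for it to drain):
```scala
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Record the root physical node of every query execution triggered by a statement.
val seenPlans = ArrayBuffer.empty[String]
val listener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    seenPlans += qe.executedPlan.nodeName
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}
spark.listenerManager.register(listener)
try {
  spark.sql("CREATE TABLE t2 USING PARQUET AS SELECT * FROM t")
  // With this change, seenPlans eventually contains both the CTAS execution and
  // the nested data-writing execution, not just a single CTAS entry.
} finally {
  spark.listenerManager.unregister(listener)
}
```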
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Improved existing tests.
Closes #39220 from ulysses-you/SPARK-41713.
Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../execution/command/createDataSourceTables.scala | 40 +++------------
.../sql/execution/datasources/DataSource.scala | 48 ++++++------------
.../spark/sql/execution/datasources/V1Writes.scala | 8 +--
.../scala/org/apache/spark/sql/ExplainSuite.scala | 7 ++-
.../adaptive/AdaptiveQueryExecSuite.scala | 58 ++++++++++++++--------
.../datasources/V1WriteCommandSuite.scala | 17 +++----
.../sql/execution/metric/SQLMetricsSuite.scala | 41 ++++++++++-----
.../spark/sql/util/DataFrameCallbackSuite.scala | 12 +++--
.../execution/CreateHiveTableAsSelectCommand.scala | 46 +++++------------
.../sql/hive/execution/HiveExplainSuite.scala | 16 ++----
.../spark/sql/hive/execution/SQLMetricsSuite.scala | 49 +++++++++++-------
11 files changed, 159 insertions(+), 183 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 9bf9f43829e..bf14ef14cf4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -20,13 +20,11 @@ package org.apache.spark.sql.execution.command
import java.net.URI
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.{CommandExecutionMode, SparkPlan}
+import org.apache.spark.sql.execution.CommandExecutionMode
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
@@ -143,29 +141,11 @@ case class CreateDataSourceTableAsSelectCommand(
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
- extends V1WriteCommand {
-
- override def fileFormatProvider: Boolean = {
- table.provider.forall { provider =>
- classOf[FileFormat].isAssignableFrom(DataSource.providingClass(provider, conf))
- }
- }
-
- override lazy val partitionColumns: Seq[Attribute] = {
- val unresolvedPartitionColumns = table.partitionColumnNames.map(UnresolvedAttribute.quoted)
- DataSource.resolvePartitionColumns(
- unresolvedPartitionColumns,
- outputColumns,
- query,
- SparkSession.active.sessionState.conf.resolver)
- }
-
- override def requiredOrdering: Seq[SortOrder] = {
- val options = table.storage.properties
- V1WritesUtils.getSortOrder(outputColumns, partitionColumns, table.bucketSpec, options)
- }
+ extends LeafRunnableCommand {
+ assert(query.resolved)
+ override def innerChildren: Seq[LogicalPlan] = query :: Nil
- override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
assert(table.tableType != CatalogTableType.VIEW)
assert(table.provider.isDefined)
@@ -187,7 +167,7 @@ case class CreateDataSourceTableAsSelectCommand(
}
saveDataIntoTable(
- sparkSession, table, table.storage.locationUri, child, SaveMode.Append, tableExists = true)
+ sparkSession, table, table.storage.locationUri, SaveMode.Append, tableExists = true)
} else {
table.storage.locationUri.foreach { p =>
DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf)
@@ -200,7 +180,7 @@ case class CreateDataSourceTableAsSelectCommand(
table.storage.locationUri
}
val result = saveDataIntoTable(
- sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false)
+ sparkSession, table, tableLocation, SaveMode.Overwrite, tableExists = false)
val tableSchema = CharVarcharUtils.getRawSchema(result.schema, sessionState.conf)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
@@ -232,7 +212,6 @@ case class CreateDataSourceTableAsSelectCommand(
session: SparkSession,
table: CatalogTable,
tableLocation: Option[URI],
- physicalPlan: SparkPlan,
mode: SaveMode,
tableExists: Boolean): BaseRelation = {
// Create the relation based on the input logical plan: `query`.
@@ -246,14 +225,11 @@ case class CreateDataSourceTableAsSelectCommand(
catalogTable = if (tableExists) Some(table) else None)
try {
- dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan, metrics)
+ dataSource.writeAndRead(mode, query, outputColumnNames)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table
${table.identifier.unquotedString}", ex)
throw ex
}
}
-
- override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
- copy(query = newChild)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 3d8eb9bc8a8..edbdd6bbc67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
import org.apache.spark.sql.connector.catalog.TableProvider
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
@@ -45,7 +44,6 @@ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
-import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -97,8 +95,19 @@ case class DataSource(
case class SourceInfo(name: String, schema: StructType, partitionColumns: Seq[String])
- lazy val providingClass: Class[_] =
- DataSource.providingClass(className, sparkSession.sessionState.conf)
+ lazy val providingClass: Class[_] = {
+ val cls = DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
+ // `providingClass` is used for resolving data source relation for catalog tables.
+ // As now catalog for data source V2 is under development, here we fall back all the
+ // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
+ // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
+ // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
+ // instead of `providingClass`.
+ cls.newInstance() match {
+ case f: FileDataSourceV2 => f.fallbackFileFormat
+ case _ => cls
+ }
+ }
private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
@@ -483,17 +492,11 @@ case class DataSource(
* @param outputColumnNames The original output column names of the input query plan. The
* optimizer may not preserve the output column's names' case, so we need
* this parameter instead of `data.output`.
- * @param physicalPlan The physical plan of the input query plan. We should run the writing
- * command with this physical plan instead of creating a new physical plan,
- * so that the metrics can be correctly linked to the given physical plan and
- * shown in the web UI.
*/
def writeAndRead(
mode: SaveMode,
data: LogicalPlan,
- outputColumnNames: Seq[String],
- physicalPlan: SparkPlan,
- metrics: Map[String, SQLMetric]): BaseRelation = {
+ outputColumnNames: Seq[String]): BaseRelation = {
val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames)
providingInstance() match {
case dataSource: CreatableRelationProvider =>
@@ -503,13 +506,8 @@ case class DataSource(
case format: FileFormat =>
disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false)
val cmd = planForWritingFileFormat(format, mode, data)
- val resolvedPartCols =
- DataSource.resolvePartitionColumns(cmd.partitionColumns, outputColumns, data, equality)
- val resolved = cmd.copy(
- partitionColumns = resolvedPartCols,
- outputColumnNames = outputColumnNames)
- resolved.run(sparkSession, physicalPlan)
- DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics)
+ val qe = sparkSession.sessionState.executePlan(cmd)
+ qe.assertCommandExecuted()
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation()
case _ => throw new IllegalStateException(
@@ -832,18 +830,4 @@ object DataSource extends Logging {
}
}
}
-
- def providingClass(className: String, conf: SQLConf): Class[_] = {
- val cls = DataSource.lookupDataSource(className, conf)
- // `providingClass` is used for resolving data source relation for catalog tables.
- // As now catalog for data source V2 is under development, here we fall back all the
- // [[FileDataSourceV2]] to [[FileFormat]] to guarantee the current catalog works.
- // [[FileDataSourceV2]] will still be used if we call the load()/save() method in
- // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
- // instead of `providingClass`.
- cls.newInstance() match {
- case f: FileDataSourceV2 => f.fallbackFileFormat
- case _ => cls
- }
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index e9f6e3df785..3ed04e5bd6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -31,11 +31,6 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String
trait V1WriteCommand extends DataWritingCommand {
- /**
- * Return if the provider is [[FileFormat]]
- */
- def fileFormatProvider: Boolean = true
-
/**
* Specify the partition columns of the V1 write command.
*/
@@ -58,8 +53,7 @@ object V1Writes extends Rule[LogicalPlan] with SQLConfHelper {
override def apply(plan: LogicalPlan): LogicalPlan = {
if (conf.plannedWriteEnabled) {
plan.transformUp {
- case write: V1WriteCommand if write.fileFormatProvider &&
- !write.child.isInstanceOf[WriteFiles] =>
+ case write: V1WriteCommand if !write.child.isInstanceOf[WriteFiles] =>
val newQuery = prepareQuery(write, write.query)
val attrMap = AttributeMap(write.query.output.zip(newQuery.output))
val newChild = WriteFiles(newQuery)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index b5353455dc2..9a75cc5ff8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -250,7 +250,12 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite
withTable("temptable") {
val df = sql("create table temptable using parquet as select * from range(2)")
withNormalizedExplain(df, SimpleMode) { normalizedOutput =>
- assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 1)
+ // scalastyle:off
+ // == Physical Plan ==
+ // Execute CreateDataSourceTableAsSelectCommand
+ // +- CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`temptable`, ErrorIfExists, Project [id#5L], [id]
+ // scalastyle:on
+ assert("Create\\w*?TableAsSelectCommand".r.findAllMatchIn(normalizedOutput).length == 2)
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 88baf76ba7a..1f10ff36acb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
-import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnionExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
-import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLExecutionStart}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
@@ -1150,18 +1150,31 @@ class AdaptiveQueryExecSuite
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true",
SQLConf.PLANNED_WRITE_ENABLED.key -> enabled.toString) {
withTable("t1") {
- val df = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col")
- val plan = df.queryExecution.executedPlan
- assert(plan.isInstanceOf[CommandResultExec])
- val commandPhysicalPlan = plan.asInstanceOf[CommandResultExec].commandPhysicalPlan
- if (enabled) {
- assert(commandPhysicalPlan.isInstanceOf[AdaptiveSparkPlanExec])
- assert(commandPhysicalPlan.asInstanceOf[AdaptiveSparkPlanExec]
- .executedPlan.isInstanceOf[DataWritingCommandExec])
- } else {
- assert(commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
- assert(commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
- .child.isInstanceOf[AdaptiveSparkPlanExec])
+ var checkDone = false
+ val listener = new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
+ if (enabled) {
+ assert(planInfo.nodeName == "AdaptiveSparkPlan")
+ assert(planInfo.children.size == 1)
+ assert(planInfo.children.head.nodeName ==
+ "Execute InsertIntoHadoopFsRelationCommand")
+ } else {
+ assert(planInfo.nodeName == "Execute InsertIntoHadoopFsRelationCommand")
+ }
+ checkDone = true
+ case _ => // ignore other events
+ }
+ }
+ }
+ spark.sparkContext.addSparkListener(listener)
+ try {
+ sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(checkDone)
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
}
}
}
@@ -1209,16 +1222,12 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
withTable("t1") {
- var checkDone = false
+ var commands: Seq[SparkPlanInfo] = Seq.empty
val listener = new SparkListener {
override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
- case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
- assert(planInfo.nodeName == "AdaptiveSparkPlan")
- assert(planInfo.children.size == 1)
- assert(planInfo.children.head.nodeName ==
- "Execute CreateDataSourceTableAsSelectCommand")
- checkDone = true
+ case start: SparkListenerSQLExecutionStart =>
+ commands = commands ++ Seq(start.sparkPlanInfo)
case _ => // ignore other events
}
}
@@ -1227,7 +1236,12 @@ class AdaptiveQueryExecSuite
try {
sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(checkDone)
+ assert(commands.size == 3)
+ assert(commands.head.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
+ assert(commands(1).nodeName == "AdaptiveSparkPlan")
+ assert(commands(1).children.size == 1)
+ assert(commands(1).children.head.nodeName == "Execute InsertIntoHadoopFsRelationCommand")
+ assert(commands(2).nodeName == "CommandResult")
} finally {
spark.sparkContext.removeSparkListener(listener)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index eb2aa09e075..e9c5c77e6d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -65,7 +65,7 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
qe.optimizedPlan match {
case w: V1WriteCommand =>
- if (hasLogicalSort) {
+ if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
assert(w.query.isInstanceOf[WriteFiles])
optimizedPlan = w.query.asInstanceOf[WriteFiles].child
} else {
@@ -86,16 +86,15 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils {
sparkContext.listenerBus.waitUntilEmpty()
+ assert(optimizedPlan != null)
// Check whether a logical sort node is at the top of the logical plan of the write query.
- if (optimizedPlan != null) {
- assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
- s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
+ assert(optimizedPlan.isInstanceOf[Sort] == hasLogicalSort,
+ s"Expect hasLogicalSort: $hasLogicalSort, Actual: ${optimizedPlan.isInstanceOf[Sort]}")
- // Check empty2null conversion.
- val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
- assert(empty2nullExpr == hasEmpty2Null,
- s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
- }
+ // Check empty2null conversion.
+ val empty2nullExpr = optimizedPlan.exists(p => V1WritesUtils.hasEmptyToNull(p.expressions))
+ assert(empty2nullExpr == hasEmpty2Null,
+ s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
spark.listenerManager.unregister(listener)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1f20fb62d37..424052df289 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -34,12 +34,13 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
-import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, SQLHadoopMapReduceCommitProtocol}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, InsertIntoHadoopFsRelationCommand, SQLHadoopMapReduceCommitProtocol, V1WriteCommand}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
// Disable AQE because metric info is different with AQE on/off
@@ -832,18 +833,32 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils
test("SPARK-34567: Add metrics for CTAS operator") {
withTable("t") {
- val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
- assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
- val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
- val dataWritingCommandExec =
- commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
- val createTableAsSelect = dataWritingCommandExec.cmd
- assert(createTableAsSelect.metrics.contains("numFiles"))
- assert(createTableAsSelect.metrics("numFiles").value == 1)
- assert(createTableAsSelect.metrics.contains("numOutputBytes"))
- assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
- assert(createTableAsSelect.metrics.contains("numOutputRows"))
- assert(createTableAsSelect.metrics("numOutputRows").value == 1)
+ var v1WriteCommand: V1WriteCommand = null
+ val listener = new QueryExecutionListener {
+ override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ qe.executedPlan match {
+ case dataWritingCommandExec: DataWritingCommandExec =>
+ val createTableAsSelect = dataWritingCommandExec.cmd
+ v1WriteCommand = createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+ case _ =>
+ }
+ }
+ }
+ spark.listenerManager.register(listener)
+ try {
+ val df = sql("CREATE TABLE t USING PARQUET AS SELECT 1 as a")
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(v1WriteCommand != null)
+ assert(v1WriteCommand.metrics.contains("numFiles"))
+ assert(v1WriteCommand.metrics("numFiles").value == 1)
+ assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+ assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+ assert(v1WriteCommand.metrics.contains("numOutputRows"))
+ assert(v1WriteCommand.metrics("numOutputRows").value == 1)
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index dd6acc983b7..2fc1f10d3ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -217,10 +217,14 @@ class DataFrameCallbackSuite extends QueryTest
withTable("tab") {
spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
sparkContext.listenerBus.waitUntilEmpty()
- assert(commands.length == 5)
- assert(commands(4)._1 == "command")
- assert(commands(4)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
- assert(commands(4)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]
+ // CTAS would derive 3 query executions
+ // 1. CreateDataSourceTableAsSelectCommand
+ // 2. InsertIntoHadoopFsRelationCommand
+ // 3. CommandResultExec
+ assert(commands.length == 6)
+ assert(commands(5)._1 == "command")
+ assert(commands(5)._2.isInstanceOf[CreateDataSourceTableAsSelectCommand])
+ assert(commands(5)._2.asInstanceOf[CreateDataSourceTableAsSelectCommand]
.table.partitionColumnNames == Seq("p"))
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index ce320775027..4dfb2cf65eb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -21,43 +21,26 @@ import scala.util.control.NonFatal
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, V1WritesUtils}
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils, LeafRunnableCommand}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation}
import org.apache.spark.sql.hive.HiveSessionCatalog
import org.apache.spark.util.Utils
-trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils {
+trait CreateHiveTableAsSelectBase extends LeafRunnableCommand {
val tableDesc: CatalogTable
val query: LogicalPlan
val outputColumnNames: Seq[String]
val mode: SaveMode
- protected val tableIdentifier = tableDesc.identifier
+ assert(query.resolved)
+ override def innerChildren: Seq[LogicalPlan] = query :: Nil
- override lazy val partitionColumns: Seq[Attribute] = {
- // If the table does not exist the schema should always be empty.
- val table = if (tableDesc.schema.isEmpty) {
- val tableSchema = CharVarcharUtils.getRawSchema(outputColumns.toStructType, conf)
- tableDesc.copy(schema = tableSchema)
- } else {
- tableDesc
- }
- // For CTAS, there is no static partition values to insert.
- val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap
- getDynamicPartitionColumns(table, partition, query)
- }
-
- override def requiredOrdering: Seq[SortOrder] = {
- val options = getOptionsWithHiveBucketWrite(tableDesc.bucketSpec)
- V1WritesUtils.getSortOrder(outputColumns, partitionColumns, tableDesc.bucketSpec, options)
- }
+ protected val tableIdentifier = tableDesc.identifier
- override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
+ override def run(sparkSession: SparkSession): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
val tableExists = catalog.tableExists(tableIdentifier)
@@ -74,8 +57,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils
}
val command = getWritingCommand(catalog, tableDesc, tableExists = true)
- command.run(sparkSession, child)
- DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics)
+ val qe = sparkSession.sessionState.executePlan(command)
+ qe.assertCommandExecuted()
} else {
tableDesc.storage.locationUri.foreach { p =>
DataWritingCommand.assertEmptyRootPath(p, mode, sparkSession.sessionState.newHadoopConf)
@@ -83,6 +66,7 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
// processing.
+ val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)
val tableSchema = CharVarcharUtils.getRawSchema(
outputColumns.toStructType, sparkSession.sessionState.conf)
assert(tableDesc.schema.isEmpty)
@@ -93,8 +77,8 @@ trait CreateHiveTableAsSelectBase extends V1WriteCommand with V1WritesHiveUtils
// Read back the metadata of the table which was created just now.
val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
val command = getWritingCommand(catalog, createdTableMeta, tableExists = false)
- command.run(sparkSession, child)
- DataWritingCommand.propogateMetrics(sparkSession.sparkContext, command, metrics)
+ val qe = sparkSession.sessionState.executePlan(command)
+ qe.assertCommandExecuted()
} catch {
case NonFatal(e) =>
// drop the created table.
@@ -154,9 +138,6 @@ case class CreateHiveTableAsSelectCommand(
override def writingCommandClassName: String =
Utils.getSimpleName(classOf[InsertIntoHiveTable])
-
- override protected def withNewChildInternal(
- newChild: LogicalPlan): CreateHiveTableAsSelectCommand = copy(query = newChild)
}
/**
@@ -204,7 +185,4 @@ case class OptimizedCreateHiveTableAsSelectCommand(
override def writingCommandClassName: String =
Utils.getSimpleName(classOf[InsertIntoHadoopFsRelationCommand])
-
- override protected def withNewChildInternal(
- newChild: LogicalPlan): OptimizedCreateHiveTableAsSelectCommand = copy(query = newChild)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 85c2cd53957..258b101dd21 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -102,10 +102,8 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
test("explain create table command") {
checkKeywordsExist(sql("explain create table temp__b using hive as select
* from src limit 2"),
- "== Physical Plan ==",
- "InsertIntoHiveTable",
- "Limit",
- "src")
+ "== Physical Plan ==",
+ "CreateHiveTableAsSelect")
checkKeywordsExist(
sql("explain extended create table temp__b using hive as select * from
src limit 2"),
@@ -113,10 +111,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils
with TestHiveSingleto
"== Analyzed Logical Plan ==",
"== Optimized Logical Plan ==",
"== Physical Plan ==",
- "CreateHiveTableAsSelect",
- "InsertIntoHiveTable",
- "Limit",
- "src")
+ "CreateHiveTableAsSelect")
checkKeywordsExist(sql(
"""
@@ -131,10 +126,7 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto
"== Analyzed Logical Plan ==",
"== Optimized Logical Plan ==",
"== Physical Plan ==",
- "CreateHiveTableAsSelect",
- "InsertIntoHiveTable",
- "Limit",
- "src")
+ "CreateHiveTableAsSelect")
}
test("explain output of physical plan should contain proper codegen stage
ID",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
index 7f6272666a6..c5a84b930a9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.execution.CommandResultExec
+import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.command.DataWritingCommandExec
+import org.apache.spark.sql.execution.datasources.{InsertIntoHadoopFsRelationCommand, V1WriteCommand}
import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.tags.SlowHiveTest
// Disable AQE because metric info is different with AQE on/off
@@ -44,23 +46,36 @@ class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton
Seq(false, true).foreach { canOptimized =>
withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> canOptimized.toString) {
withTable("t") {
- val df = sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
- assert(df.queryExecution.executedPlan.isInstanceOf[CommandResultExec])
- val commandResultExec = df.queryExecution.executedPlan.asInstanceOf[CommandResultExec]
- val dataWritingCommandExec =
- commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
- val createTableAsSelect = dataWritingCommandExec.cmd
- if (canOptimized) {
- assert(createTableAsSelect.isInstanceOf[OptimizedCreateHiveTableAsSelectCommand])
- } else {
- assert(createTableAsSelect.isInstanceOf[CreateHiveTableAsSelectCommand])
+ var v1WriteCommand: V1WriteCommand = null
+ val listener = new QueryExecutionListener {
+ override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ qe.executedPlan match {
+ case dataWritingCommandExec: DataWritingCommandExec =>
+ val createTableAsSelect = dataWritingCommandExec.cmd
+ v1WriteCommand = if (canOptimized) {
+ createTableAsSelect.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+ } else {
+ createTableAsSelect.asInstanceOf[InsertIntoHiveTable]
+ }
+ case _ =>
+ }
+ }
+ }
+ spark.listenerManager.register(listener)
+ try {
+ sql(s"CREATE TABLE t STORED AS PARQUET AS SELECT 1 as a")
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(v1WriteCommand != null)
+ assert(v1WriteCommand.metrics.contains("numFiles"))
+ assert(v1WriteCommand.metrics("numFiles").value == 1)
+ assert(v1WriteCommand.metrics.contains("numOutputBytes"))
+ assert(v1WriteCommand.metrics("numOutputBytes").value > 0)
+ assert(v1WriteCommand.metrics.contains("numOutputRows"))
+ assert(v1WriteCommand.metrics("numOutputRows").value == 1)
+ } finally {
+ spark.listenerManager.unregister(listener)
}
- assert(createTableAsSelect.metrics.contains("numFiles"))
- assert(createTableAsSelect.metrics("numFiles").value == 1)
- assert(createTableAsSelect.metrics.contains("numOutputBytes"))
- assert(createTableAsSelect.metrics("numOutputBytes").value > 0)
- assert(createTableAsSelect.metrics.contains("numOutputRows"))
- assert(createTableAsSelect.metrics("numOutputRows").value == 1)
}
}
}