Repository: spark
Updated Branches:
  refs/heads/branch-2.3 740606eb8 -> fa552c3c1


[SPARK-24867][SQL] Add AnalysisBarrier to DataFrameWriter

```Scala
      val udf1 = udf({(x: Int, y: Int) => x + y})
      val df = spark.range(0, 3).toDF("a")
        .withColumn("b", udf1($"a", udf1($"a", lit(10))))
      df.cache()
      df.write.saveAsTable("t")
```
The cache is not used because the plan built in the write path no longer matches the 
cached plan. This is a regression introduced by the AnalysisBarrier changes: not all 
Analyzer rules are idempotent, so re-analyzing an already analyzed plan can change it 
and make it differ from the plan recorded by `df.cache()`.
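
To make the symptom concrete, here is a minimal sketch (not part of this patch) that 
checks whether a cached plan is actually picked up, by looking for `InMemoryRelation` 
in `withCachedData`; it assumes an active `spark` session with `spark.implicits._` 
imported for the `$"a"` syntax. Before this fix, the analogous substitution did not 
happen in the write path because the re-analyzed plan no longer matched:

```Scala
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.functions.{lit, udf}

val udf1 = udf({ (x: Int, y: Int) => x + y })
val df = spark.range(0, 3).toDF("a")
  .withColumn("b", udf1($"a", udf1($"a", lit(10))))
df.cache()

// Once the DataFrame is cached, its own plan should be replaced by the cached
// InMemoryRelation; the write path should observe the same substitution.
assert(df.queryExecution.withCachedData.isInstanceOf[InMemoryRelation])
```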

Added a regression test that verifies the cached data is used by saveAsTable, insertInto, and the native DataSource save path.

Also found a bug in the DSV1 write path. Since it is not a regression, it is tracked in 
a separate JIRA: https://issues.apache.org/jira/browse/SPARK-24869

Author: Xiao Li <gatorsm...@gmail.com>

Closes #21821 from gatorsmile/testMaster22.

(cherry picked from commit d2e7deb59f641e93778b763d5396f73d38f9a785)
Signed-off-by: Xiao Li <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fa552c3c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fa552c3c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fa552c3c

Branch: refs/heads/branch-2.3
Commit: fa552c3c1102404fe98c72a5b83cffbc5ba41df3
Parents: 740606e
Author: Xiao Li <gatorsm...@gmail.com>
Authored: Wed Jul 25 17:22:37 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed Jul 25 17:24:32 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameWriter.scala  | 10 +++--
 .../spark/sql/execution/command/ddl.scala       |  7 ++--
 .../scala/org/apache/spark/sql/UDFSuite.scala   | 42 +++++++++++++++++++-
 3 files changed, 51 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fa552c3c/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ed7a910..6c9fb52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -254,7 +254,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
           if (writer.isPresent) {
             runCommand(df.sparkSession, "save") {
-              WriteToDataSourceV2(writer.get(), df.logicalPlan)
+              WriteToDataSourceV2(writer.get(), df.planWithBarrier)
             }
           }
 
@@ -275,7 +275,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
         sparkSession = df.sparkSession,
         className = source,
         partitionColumns = partitioningColumns.getOrElse(Nil),
-        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
+        options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
     }
   }
 
@@ -323,7 +323,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
-        query = df.logicalPlan,
+        query = df.planWithBarrier,
         overwrite = mode == SaveMode.Overwrite,
         ifPartitionNotExists = false)
     }
@@ -455,7 +455,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
       bucketSpec = getBucketSpec)
 
-    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+    runCommand(df.sparkSession, "saveAsTable") {
+      CreateTable(tableDesc, mode, Some(df.planWithBarrier))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/fa552c3c/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 0f4831b..28313f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
 
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
+import org.apache.spark.sql.catalyst.analysis.{EliminateBarriers, NoSuchTableException, Resolver}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
@@ -889,8 +889,9 @@ object DDLUtils {
    * Throws exception if outputPath tries to overwrite inputpath.
    */
   def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
-    val inputPaths = query.collect {
-      case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
+    val inputPaths = EliminateBarriers(query).collect {
+      case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
+        r.location.rootPaths
     }.flatten
 
     if (inputPaths.contains(outputPath)) {

http://git-wip-us.apache.org/repos/asf/spark/blob/fa552c3c/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 57bdec3..6f3937c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -19,11 +19,16 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand}
+import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.util.QueryExecutionListener
+
 
 private case class FunctionResult(f1: String, f2: String)
 
@@ -305,6 +310,41 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))"))
   }
 
+  test("cached Data should be used in the write path") {
+    withTable("t") {
+      withTempPath { path =>
+        var numTotalCachedHit = 0
+        val listener = new QueryExecutionListener {
+          override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {}
+
+          override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+            qe.withCachedData match {
+              case c: CreateDataSourceTableAsSelectCommand
+                  if c.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case i: InsertIntoHadoopFsRelationCommand
+                  if i.query.isInstanceOf[InMemoryRelation] =>
+                numTotalCachedHit += 1
+              case _ =>
+            }
+          }
+        }
+        spark.listenerManager.register(listener)
+
+        val udf1 = udf({ (x: Int, y: Int) => x + y })
+        val df = spark.range(0, 3).toDF("a")
+          .withColumn("b", udf1($"a", lit(10)))
+        df.cache()
+        df.write.saveAsTable("t")
+        assert(numTotalCachedHit == 1, "expected to be cached in saveAsTable")
+        df.write.insertInto("t")
+        assert(numTotalCachedHit == 2, "expected to be cached in insertInto")
+        df.write.save(path.getCanonicalPath)
+        assert(numTotalCachedHit == 3, "expected to be cached in save for native")
+      }
+    }
+  }
+
   test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
     val udf1 = udf({(x: Int, y: Int) => x + y})
     val df = spark.range(0, 3).toDF("a")

