Repository: spark
Updated Branches:
  refs/heads/master 8ab8ef773 -> bac50aa37


[SPARK-24596][SQL] Non-cascading Cache Invalidation

## What changes were proposed in this pull request?

1. Add a `cascade` parameter to `CacheManager.uncacheQuery()`. With 
`cascade=false`, only the current cache entry is invalidated; dependent caches 
are kept, with their execution plans rebuilt so that they keep reusing the 
already cached buffers.
2. Pass the appropriate `cascade` value from the callers in the different 
uncache scenarios (see the sketch after this list):
- Drop tables and regular (persistent) views: regular (cascading) mode
- Drop temporary views: non-cascading mode
- Modify table contents (INSERT/UPDATE/MERGE/DELETE): regular mode
- Call `Dataset.unpersist()`: non-cascading mode
- Call `Catalog.uncacheTable()`: follow the same convention as dropping 
tables/views, i.e., non-cascading mode for temporary views and regular mode 
for the rest
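
A minimal sketch of the `Dataset.unpersist()` case, patterned on the new 
DatasetCacheSuite test; the session setup here is an assumption for 
illustration, and the final assertion states the intended behavior rather than 
quoting the test verbatim:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.StorageLevel

// Assumed setup: a throwaway local session, only for illustration.
val spark = SparkSession.builder().master("local[*]").appName("non-cascading").getOrCreate()
import spark.implicits._

val df  = Seq(("a", 1), ("b", 2)).toDF("s", "i")
val df2 = df.filter($"i" > 1)   // a query built on top of df

df2.cache()                     // dependent cache
df.cache()
df.count()                      // materialize df's cache

df.unpersist()                  // non-cascading: removes only df's cache entry

assert(df.storageLevel == StorageLevel.NONE)
// df2 stays cached: its plan is re-compiled without df's cache,
// but its own cached buffers are kept and reused.
assert(df2.storageLevel != StorageLevel.NONE)
```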

Note that a regular (persistent) view is a database object just like a table, 
so after dropping a regular view (whether cached or not), any query referring 
to that view should no longer be valid. Hence, if a cached persistent view is 
dropped, we need to invalidate all the dependent caches so that exceptions 
will be thrown for any later reference. On the other hand, a temporary view is 
in fact equivalent to an unnamed Dataset, and dropping it should have no 
impact on caches built on queries that reference it. Thus we should do 
non-cascading uncaching for temporary views, which also guarantees consistent 
uncaching behavior between temporary views and unnamed Datasets.
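
A hedged, end-to-end sketch of this convention, modeled on the new 
CachedTableSuite tests (the `spark` session, `testData`, `someTable`, and the 
view names are illustrative assumptions):

```scala
// Temporary view: dropping it is non-cascading, so a cache built on it survives.
spark.sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
spark.sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")
spark.sql("DROP VIEW t1")
assert(spark.catalog.isCached("t2"))   // t2 is re-compiled but stays cached

// Persistent view: dropping it cascades, so dependent caches are invalidated.
spark.sql("CREATE VIEW pv AS SELECT * FROM someTable WHERE key > 1")
spark.sql("CACHE TABLE pv")
spark.sql("CACHE TABLE t3 AS SELECT * FROM pv WHERE value > 1")
spark.sql("DROP VIEW pv")
assert(!spark.catalog.isCached("t3"))  // any reference to pv is no longer valid
```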

## How was this patch tested?

New tests in CachedTableSuite and DatasetCacheSuite.

Author: Maryann Xue <maryann...@apache.org>

Closes #21594 from maryannxue/noncascading-cache.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bac50aa3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bac50aa3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bac50aa3

Branch: refs/heads/master
Commit: bac50aa37168a7612702a4503750a78ed5d59c78
Parents: 8ab8ef7
Author: Maryann Xue <maryann...@apache.org>
Authored: Mon Jun 25 07:17:30 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Jun 25 07:17:30 2018 -0700

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   |  1 +
 .../scala/org/apache/spark/sql/Dataset.scala    |  4 +-
 .../spark/sql/execution/CacheManager.scala      | 50 +++++++++++---
 .../execution/columnar/InMemoryRelation.scala   | 10 +++
 .../spark/sql/execution/command/ddl.scala       |  8 ++-
 .../spark/sql/execution/command/tables.scala    |  2 +-
 .../apache/spark/sql/internal/CatalogImpl.scala | 12 ++--
 .../org/apache/spark/sql/CachedTableSuite.scala | 66 +++++++++++++++++-
 .../apache/spark/sql/DatasetCacheSuite.scala    | 70 ++++++++++++++++++--
 9 files changed, 197 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index d2db067..196b814 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1827,6 +1827,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
   - Since Spark 2.4, Spark compares a DATE type with a TIMESTAMP type after 
promotes both sides to TIMESTAMP. To set `false` to 
`spark.sql.hive.compareDateTimestampInTimestamp` restores the previous 
behavior. This option will be removed in Spark 3.0.
   - Since Spark 2.4, creating a managed table with nonempty location is not 
allowed. An exception is thrown when attempting to create a managed table with 
nonempty location. To set `true` to 
`spark.sql.allowCreatingManagedTableUsingNonemptyLocation` restores the 
previous behavior. This option will be removed in Spark 3.0.
   - Since Spark 2.4, the type coercion rules can automatically promote the 
argument types of the variadic SQL functions (e.g., IN/COALESCE) to the widest 
common type, no matter how the input arguments order. In prior Spark versions, 
the promotion could fail in some specific orders (e.g., TimestampType, 
IntegerType and StringType) and throw an exception.
+  - Since Spark 2.4, Spark has enabled non-cascading SQL cache invalidation in 
addition to the traditional cache invalidation mechanism. The non-cascading 
cache invalidation mechanism allows users to remove a cache without impacting 
its dependent caches. This new cache invalidation mechanism is used in 
scenarios where the data of the cache to be removed is still valid, e.g., 
calling unpersist() on a Dataset, or dropping a temporary view. This allows 
users to free up memory and keep the desired caches valid at the same time.
   - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` 
respect the timezone in the input timestamp string, which breaks the assumption 
that the input timestamp is in a specific timezone. Therefore, these 2 
functions can return unexpected results. In version 2.4 and later, this problem 
has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if 
the input timestamp string contains timezone. As an example, 
`from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 
01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 
00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return 
`2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care 
about this problem and want to retain the previous behavior to keep their query 
unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. 
This option will be removed in Spark 3.0 and should only be used as a temporary 
workaround.
   - In version 2.3 and earlier, Spark converts Parquet Hive tables by default 
but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. 
This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 
'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 
2.4, Spark respects Parquet/ORC specific table properties while converting 
Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS 
PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy 
parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would 
be uncompressed parquet files.
   - Since Spark 2.0, Spark converts Parquet Hive tables by default for better 
performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. 
It means Spark uses its own ORC support by default instead of Hive SerDe. As an 
example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive 
SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC 
data source table and ORC vectorization would be applied. To set `false` to 
`spark.sql.hive.convertMetastoreOrc` restores the previous behavior.

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f552610..57f1e17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2964,6 +2964,7 @@ class Dataset[T] private[sql](
 
   /**
    * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
+   * This will not un-persist any cached data that is built upon this Dataset.
    *
    * @param blocking Whether to block until all blocks are deleted.
    *
@@ -2971,12 +2972,13 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def unpersist(blocking: Boolean): this.type = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(this, blocking)
+    sparkSession.sharedState.cacheManager.uncacheQuery(this, cascade = false, blocking)
     this
   }
 
   /**
    * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk.
+   * This will not un-persist any cached data that is built upon this Dataset.
    *
    * @group basic
    * @since 1.6.0

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 2db7c02..39d9a95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -105,24 +105,50 @@ class CacheManager extends Logging {
   }
 
   /**
-   * Un-cache all the cache entries that refer to the given plan.
+   * Un-cache the given plan or all the cache entries that refer to the given plan.
+   * @param query     The [[Dataset]] to be un-cached.
+   * @param cascade   If true, un-cache all the cache entries that refer to the given
+   *                  [[Dataset]]; otherwise un-cache the given [[Dataset]] only.
+   * @param blocking  Whether to block until all blocks are deleted.
    */
-  def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
-    uncacheQuery(query.sparkSession, query.logicalPlan, blocking)
+  def uncacheQuery(
+      query: Dataset[_],
+      cascade: Boolean,
+      blocking: Boolean = true): Unit = writeLock {
+    uncacheQuery(query.sparkSession, query.logicalPlan, cascade, blocking)
   }
 
   /**
-   * Un-cache all the cache entries that refer to the given plan.
+   * Un-cache the given plan or all the cache entries that refer to the given plan.
+   * @param spark     The Spark session.
+   * @param plan      The plan to be un-cached.
+   * @param cascade   If true, un-cache all the cache entries that refer to the given
+   *                  plan; otherwise un-cache the given plan only.
+   * @param blocking  Whether to block until all blocks are deleted.
    */
-  def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock {
+  def uncacheQuery(
+      spark: SparkSession,
+      plan: LogicalPlan,
+      cascade: Boolean,
+      blocking: Boolean): Unit = writeLock {
+    val shouldRemove: LogicalPlan => Boolean =
+      if (cascade) {
+        _.find(_.sameResult(plan)).isDefined
+      } else {
+        _.sameResult(plan)
+      }
     val it = cachedData.iterator()
     while (it.hasNext) {
       val cd = it.next()
-      if (cd.plan.find(_.sameResult(plan)).isDefined) {
+      if (shouldRemove(cd.plan)) {
         cd.cachedRepresentation.cacheBuilder.clearCache(blocking)
         it.remove()
       }
     }
+    // Re-compile dependent cached queries after removing the cached query.
+    if (!cascade) {
+      recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined, clearCache = false)
+    }
   }
 
   /**
@@ -132,20 +158,24 @@ class CacheManager extends Logging {
     recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined)
   }
 
-  private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = {
+  private def recacheByCondition(
+      spark: SparkSession,
+      condition: LogicalPlan => Boolean,
+      clearCache: Boolean = true): Unit = {
     val it = cachedData.iterator()
     val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData]
     while (it.hasNext) {
       val cd = it.next()
       if (condition(cd.plan)) {
-        cd.cachedRepresentation.cacheBuilder.clearCache()
+        if (clearCache) {
+          cd.cachedRepresentation.cacheBuilder.clearCache()
+        }
         // Remove the cache entry before we create a new one, so that we can have a different
         // physical plan.
         it.remove()
         val plan = spark.sessionState.executePlan(AnalysisBarrier(cd.plan)).executedPlan
         val newCache = InMemoryRelation(
-          cacheBuilder = cd.cachedRepresentation
-            .cacheBuilder.copy(cachedPlan = plan)(_cachedColumnBuffers = null),
+          cacheBuilder = cd.cachedRepresentation.cacheBuilder.withCachedPlan(plan),
           logicalPlan = cd.plan)
         needToRecache += cd.copy(cachedRepresentation = newCache)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index da35a47..7c8faec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -74,6 +74,16 @@ case class CachedRDDBuilder(
     }
   }
 
+  def withCachedPlan(cachedPlan: SparkPlan): CachedRDDBuilder = {
+    new CachedRDDBuilder(
+      useCompression,
+      batchSize,
+      storageLevel,
+      cachedPlan = cachedPlan,
+      tableName
+    )(_cachedColumnBuffers)
+  }
+
   private def buildBuffers(): RDD[CachedBatch] = {
     val output = cachedPlan.output
     val cached = cachedPlan.execute().mapPartitionsInternal { rowIterator =>

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index bf4d96f..04bf8c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -189,8 +189,9 @@ case class DropTableCommand(
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val catalog = sparkSession.sessionState.catalog
+    val isTempView = catalog.isTemporaryTable(tableName)
 
-    if (!catalog.isTemporaryTable(tableName) && catalog.tableExists(tableName)) {
+    if (!isTempView && catalog.tableExists(tableName)) {
       // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view
       // issue an exception.
       catalog.getTableMetadata(tableName).tableType match {
@@ -204,9 +205,10 @@ case class DropTableCommand(
       }
     }
 
-    if (catalog.isTemporaryTable(tableName) || catalog.tableExists(tableName)) {
+    if (isTempView || catalog.tableExists(tableName)) {
       try {
-        sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
+        sparkSession.sharedState.cacheManager.uncacheQuery(
+          sparkSession.table(tableName), cascade = !isTempView)
       } catch {
         case NonFatal(e) => log.warn(e.toString, e)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 4474919..ec3961f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -493,7 +493,7 @@ case class TruncateTableCommand(
     spark.sessionState.refreshTable(tableName.unquotedString)
     // Also try to drop the contents of the table from the columnar cache
     try {
-      spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier))
+      spark.sharedState.cacheManager.uncacheQuery(spark.table(table.identifier), cascade = true)
     } catch {
       case NonFatal(e) =>
         log.warn(s"Exception when attempting to uncache table 
$tableIdentWithDB", e)

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 6ae307b..4698e8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -364,7 +364,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(
+        sparkSession, viewDef, cascade = false, blocking = true)
       sessionCatalog.dropTempView(viewName)
     }
   }
@@ -379,7 +380,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   override def dropGlobalTempView(viewName: String): Boolean = {
     sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef =>
-      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(
+        sparkSession, viewDef, cascade = false, blocking = true)
       sessionCatalog.dropGlobalTempView(viewName)
     }
   }
@@ -438,7 +440,9 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName))
+    val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
+    val cascade = !sessionCatalog.isTemporaryTable(tableIdent)
+    sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName), cascade)
   }
 
   /**
@@ -490,7 +494,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
     // cached version and make the new version cached lazily.
     if (isCached(table)) {
       // Uncache the logicalPlan.
-      sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true)
+      sparkSession.sharedState.cacheManager.uncacheQuery(table, cascade = true, blocking = true)
       // Cache it again.
       sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table))
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 6982c22..60c73df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -801,4 +800,69 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     }
     assert(cachedData.collect === Seq(1001))
   }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation - uncache temporary 
view") {
+    withTempView("t1", "t2") {
+      sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
+      sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")
+
+      assert(spark.catalog.isCached("t1"))
+      assert(spark.catalog.isCached("t2"))
+      sql("UNCACHE TABLE t1")
+      assert(!spark.catalog.isCached("t1"))
+      assert(spark.catalog.isCached("t2"))
+    }
+  }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation - drop temporary view") {
+    withTempView("t1", "t2") {
+      sql("CACHE TABLE t1 AS SELECT * FROM testData WHERE key > 1")
+      sql("CACHE TABLE t2 as SELECT * FROM t1 WHERE value > 1")
+
+      assert(spark.catalog.isCached("t1"))
+      assert(spark.catalog.isCached("t2"))
+      sql("DROP VIEW t1")
+      assert(spark.catalog.isCached("t2"))
+    }
+  }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation - drop persistent view") {
+    withTable("t") {
+      spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
+        .write.format("json").saveAsTable("t")
+      withView("t1") {
+        withTempView("t2") {
+          sql("CREATE VIEW t1 AS SELECT * FROM t WHERE key > 1")
+
+          sql("CACHE TABLE t1")
+          sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")
+
+          assert(spark.catalog.isCached("t1"))
+          assert(spark.catalog.isCached("t2"))
+          sql("DROP VIEW t1")
+          assert(!spark.catalog.isCached("t2"))
+        }
+      }
+    }
+  }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation - uncache table") {
+    withTable("t") {
+      spark.range(1, 10).toDF("key").withColumn("value", 'key * 2)
+        .write.format("json").saveAsTable("t")
+      withTempView("t1", "t2") {
+        sql("CACHE TABLE t")
+        sql("CACHE TABLE t1 AS SELECT * FROM t WHERE key > 1")
+        sql("CACHE TABLE t2 AS SELECT * FROM t1 WHERE value > 1")
+
+        assert(spark.catalog.isCached("t"))
+        assert(spark.catalog.isCached("t1"))
+        assert(spark.catalog.isCached("t2"))
+        sql("UNCACHE TABLE t")
+        assert(!spark.catalog.isCached("t"))
+        assert(!spark.catalog.isCached("t1"))
+        assert(!spark.catalog.isCached("t2"))
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bac50aa3/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index c4f0563..5c6a021 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -29,6 +29,16 @@ import org.apache.spark.storage.StorageLevel
 class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits {
   import testImplicits._
 
+  /**
+   * Asserts that a cached [[Dataset]] will be built using the given number of other cached results.
+   */
+  private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = {
+    val plan = df.queryExecution.withCachedData
+    assert(plan.isInstanceOf[InMemoryRelation])
+    val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan
+    assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon)
+  }
+
   test("get storage level") {
     val ds1 = Seq("1", "2").toDS().as("a")
     val ds2 = Seq(2, 3).toDS().as("b")
@@ -117,7 +127,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
   }
 
   test("cache UDF result correctly") {
-    val expensiveUDF = udf({x: Int => Thread.sleep(10000); x})
+    val expensiveUDF = udf({x: Int => Thread.sleep(5000); x})
     val df = spark.range(0, 10).toDF("a").withColumn("b", expensiveUDF($"a"))
     val df2 = df.agg(sum(df("b")))
 
@@ -126,7 +136,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
     assertCached(df2)
 
     // udf has been evaluated during caching, and thus should not be re-evaluated here
-    failAfter(5 seconds) {
+    failAfter(3 seconds) {
       df2.collect()
     }
 
@@ -143,9 +153,57 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext with TimeLimits
     df.count()
     df2.cache()
 
-    val plan = df2.queryExecution.withCachedData
-    assert(plan.isInstanceOf[InMemoryRelation])
-    val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan
-    assert(internalPlan.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined)
+    assertCacheDependency(df2)
+  }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation") {
+    val df = Seq(("a", 1), ("b", 2)).toDF("s", "i")
+    val df2 = df.filter('i > 1)
+    val df3 = df.filter('i < 2)
+
+    df2.cache()
+    df.cache()
+    df.count()
+    df3.cache()
+
+    df.unpersist()
+
+    // df un-cached; df2 and df3's cache plan re-compiled
+    assert(df.storageLevel == StorageLevel.NONE)
+    assertCacheDependency(df2, 0)
+    assertCacheDependency(df3, 0)
+  }
+
+  test("SPARK-24596 Non-cascading Cache Invalidation - verify cached data 
reuse") {
+    val expensiveUDF = udf({ x: Int => Thread.sleep(5000); x })
+    val df = spark.range(0, 5).toDF("a")
+    val df1 = df.withColumn("b", expensiveUDF($"a"))
+    val df2 = df1.groupBy('a).agg(sum('b))
+    val df3 = df.agg(sum('a))
+
+    df1.cache()
+    df2.cache()
+    df2.collect()
+    df3.cache()
+
+    assertCacheDependency(df2)
+
+    df1.unpersist(blocking = true)
+
+    // df1 un-cached; df2's cache plan re-compiled
+    assert(df1.storageLevel == StorageLevel.NONE)
+    assertCacheDependency(df1.groupBy('a).agg(sum('b)), 0)
+
+    val df4 = df1.groupBy('a).agg(sum('b)).agg(sum("sum(b)"))
+    assertCached(df4)
+    // reuse loaded cache
+    failAfter(3 seconds) {
+      checkDataset(df4, Row(10))
+    }
+
+    val df5 = df.agg(sum('a)).filter($"sum(a)" > 1)
+    assertCached(df5)
+    // first time use, load cache
+    checkDataset(df5, Row(10))
   }
 }

