Repository: spark
Updated Branches:
  refs/heads/master da9aeb0fd -> 5aeb7384c


[SPARK-16063][SQL] Add storageLevel to Dataset

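[SPARK-11905](https://issues.apache.org/jira/browse/SPARK-11905) added support
for `persist`/`cache` for `Dataset`. However, there is no user-facing API to
check whether a `Dataset` is cached and, if so, what its storage level is. This PR
adds a `storageLevel` method to `Dataset`, analogous to `RDD.getStorageLevel`; it
returns `StorageLevel.NONE` when the `Dataset` is not persisted. On the Python side,
it adds a matching `DataFrame.storageLevel` property, changes the default storage
level of `DataFrame.persist()` from `MEMORY_ONLY` to `MEMORY_AND_DISK` to match
Scala, and corrects the `cache()` docstring accordingly.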

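Updated `DatasetCacheSuite`: added a "get storage level" test and replaced direct
`cacheManager.lookupCachedData(...).isEmpty` checks with assertions on `storageLevel`.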

Author: Nick Pentreath <ni...@za.ibm.com>

Closes #13780 from MLnick/ds-storagelevel.

Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5aeb7384
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5aeb7384
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5aeb7384

Branch: refs/heads/master
Commit: 5aeb7384c7aa5f487f031f9ae07d3f1653399d14
Parents: da9aeb0
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Fri Oct 14 15:07:32 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Oct 14 15:09:49 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 36 ++++++++++++++++----
 .../scala/org/apache/spark/sql/Dataset.scala    | 12 +++++++
 .../apache/spark/sql/DatasetCacheSuite.scala    | 36 ++++++++++++++------
 3 files changed, 68 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5aeb7384/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ce277eb..7606ac0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -407,24 +407,48 @@ class DataFrame(object):
 
     @since(1.3)
     def cache(self):
-        """ Persists with the default storage level (C{MEMORY_ONLY}).
+        """Persists the :class:`DataFrame` with the default storage level 
(C{MEMORY_AND_DISK}).
+
+        .. note:: the default storage level has changed to C{MEMORY_AND_DISK} 
to match Scala in 2.0.
         """
         self.is_cached = True
         self._jdf.cache()
         return self
 
     @since(1.3)
-    def persist(self, storageLevel=StorageLevel.MEMORY_ONLY):
-        """Sets the storage level to persist its values across operations
-        after the first time it is computed. This can only be used to assign
-        a new storage level if the RDD does not have a storage level set yet.
-        If no storage level is specified defaults to (C{MEMORY_ONLY}).
+    def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK):
+        """Sets the storage level to persist the contents of the 
:class:`DataFrame` across
+        operations after the first time it is computed. This can only be used 
to assign
+        a new storage level if the :class:`DataFrame` does not have a storage 
level set yet.
+        If no storage level is specified defaults to (C{MEMORY_AND_DISK}).
+
+        .. note:: the default storage level has changed to C{MEMORY_AND_DISK} 
to match Scala in 2.0.
         """
         self.is_cached = True
         javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
         self._jdf.persist(javaStorageLevel)
         return self
 
+    @property
+    @since(2.1)
+    def storageLevel(self):
+        """Get the :class:`DataFrame`'s current storage level.
+
+        >>> df.storageLevel
+        StorageLevel(False, False, False, False, 1)
+        >>> df.cache().storageLevel
+        StorageLevel(True, True, False, True, 1)
+        >>> df2.persist(StorageLevel.DISK_ONLY_2).storageLevel
+        StorageLevel(True, False, False, False, 2)
+        """
+        java_storage_level = self._jdf.storageLevel()
+        storage_level = StorageLevel(java_storage_level.useDisk(),
+                                     java_storage_level.useMemory(),
+                                     java_storage_level.useOffHeap(),
+                                     java_storage_level.deserialized(),
+                                     java_storage_level.replication())
+        return storage_level
+
     @since(1.3)
     def unpersist(self, blocking=False):
         """Marks the :class:`DataFrame` as non-persistent, and remove all 
blocks for it from

http://git-wip-us.apache.org/repos/asf/spark/blob/5aeb7384/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e59a483..70c9cf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2402,6 +2402,18 @@ class Dataset[T] private[sql](
   }
 
   /**
+   * Get the Dataset's current storage level, or StorageLevel.NONE if not 
persisted.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  def storageLevel: StorageLevel = {
+    sparkSession.sharedState.cacheManager.lookupCachedData(this).map { 
cachedData =>
+      cachedData.cachedRepresentation.storageLevel
+    }.getOrElse(StorageLevel.NONE)
+  }
+
+  /**
    * Mark the Dataset as non-persistent, and remove all blocks for it from 
memory and disk.
    *
    * @param blocking Whether to block until all blocks are deleted.

http://git-wip-us.apache.org/repos/asf/spark/blob/5aeb7384/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
index 8d5e964..e0561ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -19,11 +19,32 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.storage.StorageLevel
 
 
 class DatasetCacheSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  test("get storage level") {
+    val ds1 = Seq("1", "2").toDS().as("a")
+    val ds2 = Seq(2, 3).toDS().as("b")
+
+    // default storage level
+    ds1.persist()
+    ds2.cache()
+    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
+    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
+    // unpersist
+    ds1.unpersist()
+    assert(ds1.storageLevel == StorageLevel.NONE)
+    // non-default storage level
+    ds1.persist(StorageLevel.MEMORY_ONLY_2)
+    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
+    // joined Dataset should not be persisted
+    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
+    assert(joined.storageLevel == StorageLevel.NONE)
+  }
+
   test("persist and unpersist") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 
1").as[Int])
     val cached = ds.cache()
@@ -37,8 +58,7 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
     // Drop the cache.
     cached.unpersist()
-    assert(spark.sharedState.cacheManager.lookupCachedData(cached).isEmpty,
-      "The Dataset should not be cached.")
+    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not 
be cached.")
   }
 
   test("persist and then rebind right encoder when join 2 datasets") {
@@ -55,11 +75,9 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assertCached(joined, 2)
 
     ds1.unpersist()
-    assert(spark.sharedState.cacheManager.lookupCachedData(ds1).isEmpty,
-      "The Dataset ds1 should not be cached.")
+    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not 
be cached.")
     ds2.unpersist()
-    assert(spark.sharedState.cacheManager.lookupCachedData(ds2).isEmpty,
-      "The Dataset ds2 should not be cached.")
+    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not 
be cached.")
   }
 
   test("persist and then groupBy columns asKey, map") {
@@ -74,10 +92,8 @@ class DatasetCacheSuite extends QueryTest with SharedSQLContext {
     assertCached(agged.filter(_._1 == "b"))
 
     ds.unpersist()
-    assert(spark.sharedState.cacheManager.lookupCachedData(ds).isEmpty,
-      "The Dataset ds should not be cached.")
+    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be 
cached.")
     agged.unpersist()
-    assert(spark.sharedState.cacheManager.lookupCachedData(agged).isEmpty,
-      "The Dataset agged should not be cached.")
+    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should 
not be cached.")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to