Posted to commits@spark.apache.org by ma...@apache.org on 2015/12/01 19:39:03 UTC

spark git commit: [SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs

Repository: spark
Updated Branches:
  refs/heads/master fd95eeaf4 -> 0a7bca2da


[SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs

Persist and unpersist exist in both the RDD and DataFrame APIs, and I think they are still very important in the Dataset API. Is my understanding correct? If so, could you help me check whether the implementation is acceptable?

Please provide your opinions. marmbrus rxin cloud-fan

Thank you very much!
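
For reference, here is a minimal usage sketch of the API this patch adds (it assumes a spark-shell-style session where sqlContext and its implicits are already available):

    import org.apache.spark.storage.StorageLevel
    import sqlContext.implicits._

    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

    ds.cache()        // same as ds.persist(); the default level is MEMORY_AND_DISK
    ds.count()        // the first action materializes the in-memory columnar cache
    ds.unpersist()    // drop the cached blocks (non-blocking by default)

    // An explicit storage level can be requested instead of the default:
    val values = ds.map(_._2).persist(StorageLevel.DISK_ONLY)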

Author: gatorsmile <ga...@gmail.com>
Author: xiaoli <li...@gmail.com>
Author: Xiao Li <xi...@Xiaos-MacBook-Pro.local>

Closes #9889 from gatorsmile/persistDS.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a7bca2d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a7bca2d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a7bca2d

Branch: refs/heads/master
Commit: 0a7bca2da04aefff16f2513ec27a92e69ceb77f6
Parents: fd95eea
Author: gatorsmile <ga...@gmail.com>
Authored: Tue Dec 1 10:38:59 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 1 10:38:59 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/DataFrame.scala  |  9 +++
 .../scala/org/apache/spark/sql/Dataset.scala    | 50 +++++++++++-
 .../scala/org/apache/spark/sql/SQLContext.scala |  9 +++
 .../spark/sql/execution/CacheManager.scala      | 27 +++----
 .../apache/spark/sql/DatasetCacheSuite.scala    | 80 ++++++++++++++++++++
 .../scala/org/apache/spark/sql/QueryTest.scala  |  5 +-
 6 files changed, 162 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6197f10..eb87003 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1584,6 +1584,7 @@ class DataFrame private[sql](
   def distinct(): DataFrame = dropDuplicates()
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
    * @group basic
    * @since 1.3.0
    */
@@ -1593,12 +1594,17 @@ class DataFrame private[sql](
   }
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
    * @group basic
    * @since 1.3.0
    */
   def cache(): this.type = persist()
 
   /**
+   * Persist this [[DataFrame]] with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
    * @group basic
    * @since 1.3.0
    */
@@ -1608,6 +1614,8 @@ class DataFrame private[sql](
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @param blocking Whether to block until all blocks are deleted.
    * @group basic
    * @since 1.3.0
    */
@@ -1617,6 +1625,7 @@ class DataFrame private[sql](
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
    * @group basic
    * @since 1.3.0
    */

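In use, the newly documented DataFrame behaviour looks roughly like the following sketch (the people.json path is hypothetical and an existing sqlContext is assumed):

    import org.apache.spark.storage.StorageLevel

    val df = sqlContext.read.json("people.json")   // hypothetical input file

    df.persist(StorageLevel.MEMORY_AND_DISK_SER)   // explicit storage level
    df.count()                                     // builds the in-memory columnar cache
    df.unpersist(blocking = true)                  // block until all cached blocks are removed
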
http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c357f88..d6bb1d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
 /**
@@ -565,7 +566,7 @@ class Dataset[T] private[sql](
    * combined.
    *
    * Note that, this function is not a typical set union operation, in that it does not eliminate
-   * duplicate items.  As such, it is analagous to `UNION ALL` in SQL.
+   * duplicate items.  As such, it is analogous to `UNION ALL` in SQL.
    * @since 1.6.0
    */
   def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
@@ -618,7 +619,6 @@ class Dataset[T] private[sql](
       case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
 
-
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
     withPlan[(T, U)](other) { (left, right) =>
@@ -697,11 +697,55 @@ class Dataset[T] private[sql](
    */
   def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
 
+  /**
+    * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+    * @since 1.6.0
+    */
+  def persist(): this.type = {
+    sqlContext.cacheManager.cacheQuery(this)
+    this
+  }
+
+  /**
+    * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+    * @since 1.6.0
+    */
+  def cache(): this.type = persist()
+
+  /**
+    * Persist this [[Dataset]] with the given storage level.
+    * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+    *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+    *                 `MEMORY_AND_DISK_2`, etc.
+    * @group basic
+    * @since 1.6.0
+    */
+  def persist(newLevel: StorageLevel): this.type = {
+    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+    this
+  }
+
+  /**
+    * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+    * @param blocking Whether to block until all blocks are deleted.
+    * @since 1.6.0
+    */
+  def unpersist(blocking: Boolean): this.type = {
+    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+    this
+  }
+
+  /**
+    * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+    * @since 1.6.0
+    */
+  def unpersist(): this.type = unpersist(blocking = false)
+
   /* ******************** *
    *  Internal Functions  *
    * ******************** */
 
-  private[sql] def logicalPlan = queryExecution.analyzed
+  private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
     new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)

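Because the new methods return this.type, caching does not erase the element type of a Dataset. A small sketch, again assuming sqlContext implicits are in scope:

    import org.apache.spark.sql.Dataset
    import org.apache.spark.storage.StorageLevel
    import sqlContext.implicits._

    // persist() returns this.type, so the result is still a Dataset[String], not a DataFrame.
    val words: Dataset[String] = Seq("spark", "sql", "dataset").toDS().persist()

    // Typed operations keep working on the cached Dataset.
    val lengths: Dataset[Int] = words.map(_.length).persist(StorageLevel.MEMORY_ONLY)

    words.unpersist(blocking = false)
    lengths.unpersist()
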
http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9cc65de..4e26250 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -339,6 +339,15 @@ class SQLContext private[sql](
   }
 
   /**
+    * Returns true if the [[Queryable]] is currently cached in-memory.
+    * @group cachemgmt
+    * @since 1.3.0
+    */
+  private[sql] def isCached(qName: Queryable): Boolean = {
+    cacheManager.lookupCachedData(qName).nonEmpty
+  }
+
+  /**
    * Caches the specified table in-memory.
    * @group cachemgmt
    * @since 1.3.0

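Note that the new isCached(Queryable) overload is private[sql], so only code living in the org.apache.spark.sql package (such as the DatasetCacheSuite added below) can call it. A hypothetical test-style sketch:

    package org.apache.spark.sql

    import org.apache.spark.sql.test.SharedSQLContext

    // Hypothetical suite name; the caller must sit inside org.apache.spark.sql
    // because isCached(Queryable) is private[sql].
    class CacheStatusSketchSuite extends QueryTest with SharedSQLContext {
      import testImplicits._

      test("isCached reflects persist and unpersist") {
        val ds = Seq(1, 2, 3).toDS().cache()
        ds.count()
        assert(sqlContext.isCached(ds), "expected the Dataset to be cached")
        ds.unpersist()
        assert(!sqlContext.isCached(ds), "expected the Dataset to no longer be cached")
      }
    }
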
http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 293fcfe..50f6562 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
@@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging {
   }
 
   /**
-   * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike
-   * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
-   * the in-memory columnar representation of the underlying table is expensive.
+   * Caches the data produced by the logical representation of the given [[Queryable]].
+   * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
+   * recomputing the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: DataFrame,
+      query: Queryable,
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
@@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging {
             sqlContext.conf.useCompression,
             sqlContext.conf.columnBatchSize,
             storageLevel,
-            sqlContext.executePlan(query.logicalPlan).executedPlan,
+            sqlContext.executePlan(planToCache).executedPlan,
             tableName))
     }
   }
 
-  /** Removes the data for the given [[DataFrame]] from the cache */
-  private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
+  /** Removes the data for the given [[Queryable]] from the cache */
+  private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
@@ -109,9 +108,11 @@ private[sql] class CacheManager extends Logging {
     cachedData.remove(dataIndex)
   }
 
-  /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
+  /** Tries to remove the data for the given [[Queryable]] from the cache
+    * if it's cached
+    */
   private[sql] def tryUncacheQuery(
-      query: DataFrame,
+      query: Queryable,
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging {
     found
   }
 
-  /** Optionally returns cached data for the given [[DataFrame]] */
-  private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
+  /** Optionally returns cached data for the given [[Queryable]] */
+  private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
 
-  /** Optionally returns cached data for the given LogicalPlan. */
+  /** Optionally returns cached data for the given [[LogicalPlan]]. */
   private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
     cachedData.find(cd => plan.sameResult(cd.plan))
   }

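Since CacheManager keys entries on the analyzed logical plan and matches them with sameResult, a Dataset and a DataFrame built over the same plan share one cache entry. A rough sketch (sqlContext implicits assumed):

    import sqlContext.implicits._

    val ds = Seq(("a", 1), ("b", 2)).toDS()
    ds.persist()
    ds.count()        // materializes the cached InMemoryRelation

    // ds.toDF() wraps the same analyzed plan, so the sameResult lookup finds
    // the existing cache entry instead of creating a second one.
    val df = ds.toDF()
    df.count()        // answered from the cached columnar data
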
http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
new file mode 100644
index 0000000..3a283a4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class DatasetCacheSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("persist and unpersist") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
+    val cached = ds.cache()
+    // count triggers the caching action. It should not throw.
+    cached.count()
+    // Make sure, the Dataset is indeed cached.
+    assertCached(cached)
+    // Check result.
+    checkAnswer(
+      cached,
+      2, 3, 4)
+    // Drop the cache.
+    cached.unpersist()
+    assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
+  }
+
+  test("persist and then rebind right encoder when join 2 datasets") {
+    val ds1 = Seq("1", "2").toDS().as("a")
+    val ds2 = Seq(2, 3).toDS().as("b")
+
+    ds1.persist()
+    assertCached(ds1)
+    ds2.persist()
+    assertCached(ds2)
+
+    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
+    checkAnswer(joined, ("2", 2))
+    assertCached(joined, 2)
+
+    ds1.unpersist()
+    assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
+    ds2.unpersist()
+    assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
+  }
+
+  test("persist and then groupBy columns asKey, map") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1").keyAs[String]
+    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
+    agged.persist()
+
+    checkAnswer(
+      agged.filter(_._1 == "b"),
+      ("b", 3))
+    assertCached(agged.filter(_._1 == "b"))
+
+    ds.unpersist()
+    assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
+    agged.unpersist()
+    assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a7bca2d/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8f476dd..bc22fb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.Queryable
 
 abstract class QueryTest extends PlanTest {
 
@@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest {
   }
 
   /**
-   * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
+   * Asserts that a given [[Queryable]] will be executed using the given number of cached results.
    */
-  def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+  def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
     val planWithCaching = query.queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org