diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6197f10813a3b6e75d592c69996874e49f948bae..eb8700369275e4961fad4e14863d104c5d0ad1d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1584,6 +1584,13 @@ class DataFrame private[sql](
   def distinct(): DataFrame = dropDuplicates()
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
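+   *
+   * {{{
+   *   // For example (assuming a SQLContext named `sqlContext`; names are illustrative):
+   *   val df = sqlContext.range(0, 1000)
+   *   df.persist()   // same as df.cache(); uses MEMORY_AND_DISK
+   * }}}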
    * @group basic
    * @since 1.3.0
    */
@@ -1593,12 +1600,23 @@
   }
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
    * @group basic
    * @since 1.3.0
    */
   def cache(): this.type = persist()
 
   /**
+   * Persist this [[DataFrame]] with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
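+   *
+   * {{{
+   *   // For example, given a DataFrame `df`, keep a serialized copy in memory,
+   *   // spilling to disk if needed (`df` is illustrative):
+   *   import org.apache.spark.storage.StorageLevel
+   *   df.persist(StorageLevel.MEMORY_AND_DISK_SER)
+   * }}}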
    * @group basic
    * @since 1.3.0
    */
@@ -1608,6 +1627,13 @@
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @param blocking Whether to block until all blocks are deleted.
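+   *
+   * {{{
+   *   // For example, given a cached DataFrame `df`, block until its blocks are removed:
+   *   df.unpersist(blocking = true)
+   * }}}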
    * @group basic
    * @since 1.3.0
    */
@@ -1617,6 +1643,7 @@
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
    * @group basic
    * @since 1.3.0
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c357f88a94dd0ce3593ba67bfa05a7e3c9fedaa3..d6bb1d2ad8e50fa340aa9eb8c400971248b649e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
 /**
@@ -565,7 +566,7 @@ class Dataset[T] private[sql](
    * combined.
    *
    * Note that, this function is not a typical set union operation, in that it does not eliminate
-   * duplicate items.  As such, it is analagous to `UNION ALL` in SQL.
+   * duplicate items.  As such, it is analogous to `UNION ALL` in SQL.
    * @since 1.6.0
    */
   def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
@@ -618,7 +619,6 @@ class Dataset[T] private[sql](
       case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
 
-
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
     withPlan[(T, U)](other) { (left, right) =>
@@ -697,11 +697,60 @@ class Dataset[T] private[sql](
    */
   def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
 
+  /**
+   * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+   * @since 1.6.0
+   */
+  def persist(): this.type = {
+    sqlContext.cacheManager.cacheQuery(this)
+    this
+  }
+
+  /**
+   * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+   * @since 1.6.0
+   */
+  def cache(): this.type = persist()
+
+  /**
+   * Persist this [[Dataset]] with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
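+   *
+   * {{{
+   *   // For example (assuming `import sqlContext.implicits._` is in scope; names are illustrative):
+   *   val ds = Seq(1, 2, 3).toDS()
+   *   ds.persist(StorageLevel.MEMORY_ONLY)
+   * }}}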
+   * @since 1.6.0
+   */
+  def persist(newLevel: StorageLevel): this.type = {
+    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+    this
+  }
+
+  /**
+   * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @param blocking Whether to block until all blocks are deleted.
+   * @since 1.6.0
+   */
+  def unpersist(blocking: Boolean): this.type = {
+    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+    this
+  }
+
+  /**
+   * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @since 1.6.0
+   */
+  def unpersist(): this.type = unpersist(blocking = false)
+
   /* ******************** *
    *  Internal Functions  *
    * ******************** */
 
-  private[sql] def logicalPlan = queryExecution.analyzed
+  private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
     new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 9cc65de19180aca8b3a53d47e5c95c1363f00bf8..4e26250868374383254bdddaa2af9aa8401b8c7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -338,6 +338,15 @@ class SQLContext private[sql](
     cacheManager.lookupCachedData(table(tableName)).nonEmpty
   }
 
+  /**
+   * Returns true if the [[Queryable]] is currently cached in-memory.
+   * @group cachemgmt
+   * @since 1.6.0
+   */
+  private[sql] def isCached(query: Queryable): Boolean = {
+    cacheManager.lookupCachedData(query).nonEmpty
+  }
+
   /**
    * Caches the specified table in-memory.
    * @group cachemgmt
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 293fcfe96e677f8a92a5cf9a20582097dc5592e3..50f6562815c21c2f6cad05e534ec7545e9783a55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
@@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging {
   }
 
   /**
-   * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike
-   * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
-   * the in-memory columnar representation of the underlying table is expensive.
+   * Caches the data produced by the logical representation of the given [[Queryable]].
+   * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
+   * recomputing the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: DataFrame,
+      query: Queryable,
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
@@ -95,13 +94,13 @@ private[sql] class CacheManager extends Logging {
             sqlContext.conf.useCompression,
             sqlContext.conf.columnBatchSize,
             storageLevel,
-            sqlContext.executePlan(query.logicalPlan).executedPlan,
+            sqlContext.executePlan(planToCache).executedPlan,
             tableName))
     }
   }
 
-  /** Removes the data for the given [[DataFrame]] from the cache */
-  private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
+  /** Removes the data for the given [[Queryable]] from the cache */
+  private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
@@ -109,9 +108,9 @@
     cachedData.remove(dataIndex)
   }
 
-  /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
+  /** Tries to remove the data for the given [[Queryable]] from the cache if it's cached */
   private[sql] def tryUncacheQuery(
-      query: DataFrame,
+      query: Queryable,
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,12 +122,12 @@
     found
   }
 
-  /** Optionally returns cached data for the given [[DataFrame]] */
-  private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
+  /** Optionally returns cached data for the given [[Queryable]] */
+  private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
 
-  /** Optionally returns cached data for the given LogicalPlan. */
+  /** Optionally returns cached data for the given [[LogicalPlan]]. */
   private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
     cachedData.find(cd => plan.sameResult(cd.plan))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3a283a4e1f610c2271768046fe363889054b748c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.postfixOps
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+
+class DatasetCacheSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+
+  test("persist and unpersist") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
+    val cached = ds.cache()
+    // count triggers the caching action. It should not throw.
+    cached.count()
+    // Make sure the Dataset is indeed cached.
+    assertCached(cached)
+    // Check result.
+    checkAnswer(
+      cached,
+      2, 3, 4)
+    // Drop the cache.
+    cached.unpersist()
+    assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
+  }
+
+  test("persist and then rebind right encoder when join 2 datasets") {
+    val ds1 = Seq("1", "2").toDS().as("a")
+    val ds2 = Seq(2, 3).toDS().as("b")
+
+    ds1.persist()
+    assertCached(ds1)
+    ds2.persist()
+    assertCached(ds2)
+
+    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
+    checkAnswer(joined, ("2", 2))
+    assertCached(joined, 2)
+
+    ds1.unpersist()
+    assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
+    ds2.unpersist()
+    assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
+  }
+
+  test("persist and then groupBy columns asKey, map") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1").keyAs[String]
+    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
+    agged.persist()
+
+    checkAnswer(
+      agged.filter(_._1 == "b"),
+      ("b", 3))
+    assertCached(agged.filter(_._1 == "b"))
+
+    ds.unpersist()
+    assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
+    agged.unpersist()
+    assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8f476dd0f99b6de5916db707a26d2489d2282846..bc22fb8b7bdb423bf6f18d298f1a344e06ae1c9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.Queryable
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 
 abstract class QueryTest extends PlanTest {
 
@@ -163,9 +164,9 @@ abstract class QueryTest extends PlanTest {
   }
 
   /**
-   * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
+   * Asserts that a given [[Queryable]] will be executed using the given number of cached results.
    */
-  def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+  def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
     val planWithCaching = query.queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached