diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 286d8549bfe27b4edd93c755bb62e6fc555261bc..6e0a2471e0fb520d8c64b9435fc0e6f5e13a32ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
+import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -482,6 +483,58 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming
 
+  /**
+   * Eagerly checkpoints this Dataset and returns the new checkpointed Dataset. Checkpointing can
+   * be used to truncate the logical plan, which is especially useful in iterative algorithms
+   * where the plan may grow exponentially. The data is saved to files inside the checkpoint
+   * directory set with `SparkContext#setCheckpointDir`.
+   *
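+   * A minimal usage sketch, where `ds` denotes an existing Dataset and the checkpoint path is
+   * purely illustrative:
+   * {{{
+   *   spark.sparkContext.setCheckpointDir("/tmp/checkpoints")
+   *   val checkpointed = ds.checkpoint()
+   * }}}
+   *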
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+
+  /**
+   * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the
+   * logical plan, which is especially useful in iterative algorithms where the plan may grow
+   * exponentially. The data is saved to files inside the checkpoint directory set with
+   * `SparkContext#setCheckpointDir`.
+   *
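+   * For example, a lazy checkpoint is only materialized once the returned Dataset is first
+   * computed (a sketch, assuming `ds` is an existing Dataset and a checkpoint directory has
+   * already been set):
+   * {{{
+   *   val lazyCp = ds.checkpoint(eager = false)
+   *   lazyCp.count()  // checkpoint files are written during this first action
+   * }}}
+   *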
+   * @param eager When true, materializes the underlying checkpointed RDD eagerly.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(eager: Boolean): Dataset[T] = {
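+    // Rows produced by `toRdd` may be backed by reused mutable buffers, so copy each row before
+    // checkpointing it.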
+    val internalRdd = queryExecution.toRdd.map(_.copy())
+    internalRdd.checkpoint()
+
+    if (eager) {
+      internalRdd.count()
+    }
+
+    val physicalPlan = queryExecution.executedPlan
+
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
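+    // (For example, an inner sort-merge join reports a `PartitioningCollection` covering both of
+    // its children, so deeply nested joins nest these collections as well.)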
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(
+        logicalPlan.output,
+        internalRdd,
+        outputPartitioning,
+        physicalPlan.outputOrdering
+      )(sparkSession)).as[T]
+  }
+
   /**
    * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d3a22228623e1e7b02668d07e36a5b1ff0f4af57..455fb5bfbb6f772e75fccc05239392937467b9ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
 /** Logical plan node for scanning data from an RDD of InternalRow. */
 case class LogicalRDD(
     output: Seq[Attribute],
-    rdd: RDD[InternalRow])(session: SparkSession)
+    rdd: RDD[InternalRow],
+    outputPartitioning: Partitioning = UnknownPartitioning(0),
+    outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {
 
   override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
 
-  override def newInstance(): LogicalRDD.this.type =
-    LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
+  override def newInstance(): LogicalRDD.this.type = {
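+    // `newInstance()` re-creates the output attributes with fresh expression IDs, so attributes
+    // referenced by the reported partitioning and ordering must be rewritten as well.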
+    val rewrite = output.zip(output.map(_.newInstance())).toMap
+
+    val rewrittenPartitioning = outputPartitioning match {
+      case p: Expression =>
+        p.transform {
+          case e: Attribute => rewrite.getOrElse(e, e)
+        }.asInstanceOf[Partitioning]
+
+      case p => p
+    }
+
+    val rewrittenOrdering = outputOrdering.map(_.transform {
+      case e: Attribute => rewrite.getOrElse(e, e)
+    }.asInstanceOf[SortOrder])
+
+    LogicalRDD(
+      output.map(rewrite),
+      rdd,
+      rewrittenPartitioning,
+      rewrittenOrdering
+    )(session).asInstanceOf[this.type]
+  }
 
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan.canonicalized match {
-      case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
       case _ => false
     }
   }
@@ -158,7 +181,9 @@ case class LogicalRDD(
 case class RDDScanExec(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    override val nodeName: String) extends LeafExecNode {
+    override val nodeName: String,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7cfae5ce283bfa7e589fd0310b4b9ac44b62f33f..5412aca95dcf1974a3c118817c31ba3c041c4903 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
 
 /**
  * Converts a logical plan into zero or more SparkPlans.  This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case r : logical.Range =>
+      case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
-      case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
+      case r: LogicalRDD =>
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cc367acae2ba4359a8af7690186979c8841d7f36..55f04878052aab114ded96848d27608a01a6a411 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
@@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  Seq(true, false).foreach { eager =>
+    def testCheckpointing(testName: String)(f: => Unit): Unit = {
+      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
+        withTempDir { dir =>
+          val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+          try {
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+            f
+          } finally {
+            // Since the original checkpointDir can be None, we need
+            // to set the variable directly.
+            spark.sparkContext.checkpointDir = originalCheckpointDir
+          }
+        }
+      }
+    }
+
+    testCheckpointing("basic") {
+      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+      val cp = ds.checkpoint(eager)
+
+      val logicalRDD = cp.logicalPlan match {
+        case plan: LogicalRDD => plan
+        case _ =>
+          val treeString = cp.logicalPlan.treeString(verbose = true)
+          fail(s"Expecting a LogicalRDD, but got\n$treeString")
+      }
+
+      val dsPhysicalPlan = ds.queryExecution.executedPlan
+      val cpPhysicalPlan = cp.queryExecution.executedPlan
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+
+      // For a lazy checkpoint() call, the first `checkDataset` call below also materializes
+      // the checkpoint.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+
+      // Reads the data back from the checkpoint files and checks again.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+    }
+
+    testCheckpointing("should preserve partitioning information") {
+      val ds = spark.range(10).repartition('id % 2)
+      val cp = ds.checkpoint(eager)
+
+      val agg = cp.groupBy('id % 2).agg(count('id))
+
+      agg.queryExecution.executedPlan.collectFirst {
+        case ShuffleExchange(_, _: RDDScanExec, _) =>
+        case BroadcastExchangeExec(_, _: RDDScanExec) =>
+      }.foreach { _ =>
+        fail(
+          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+        )
+      }
+
+      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+    }
+  }
 }
 
 case class Generic[T](id: T, value: Double)