diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 286d8549bfe27b4edd93c755bb62e6fc555261bc..6e0a2471e0fb520d8c64b9435fc0e6f5e13a32ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection} import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery} +import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -482,6 +483,58 @@ class Dataset[T] private[sql]( @InterfaceStability.Evolving def isStreaming: Boolean = logicalPlan.isStreaming + /** + * Returns a checkpointed version of this Dataset. + * + * @group basic + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + def checkpoint(): Dataset[T] = checkpoint(eager = true) + + /** + * Returns a checkpointed version of this Dataset. + * + * @param eager When true, materializes the underlying checkpointed RDD eagerly. + * + * @group basic + * @since 2.1.0 + */ + @Experimental + @InterfaceStability.Evolving + def checkpoint(eager: Boolean): Dataset[T] = { + val internalRdd = queryExecution.toRdd.map(_.copy()) + internalRdd.checkpoint() + + if (eager) { + internalRdd.count() + } + + val physicalPlan = queryExecution.executedPlan + + // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the + // size of `PartitioningCollection` may grow exponentially for queries involving deep inner + // joins. + def firstLeafPartitioning(partitioning: Partitioning): Partitioning = { + partitioning match { + case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head) + case p => p + } + } + + val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning) + + Dataset.ofRows( + sparkSession, + LogicalRDD( + logicalPlan.output, + internalRdd, + outputPartitioning, + physicalPlan.outputOrdering + )(sparkSession)).as[T] + } + /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d3a22228623e1e7b02668d07e36a5b1ff0f4af57..455fb5bfbb6f772e75fccc05239392937467b9ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T]( /** Logical plan node for scanning data from an RDD of InternalRow. */ case class LogicalRDD( output: Seq[Attribute], - rdd: RDD[InternalRow])(session: SparkSession) + rdd: RDD[InternalRow], + outputPartitioning: Partitioning = UnknownPartitioning(0), + outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession) extends LeafNode with MultiInstanceRelation { override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil - override def newInstance(): LogicalRDD.this.type = - LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type] + override def newInstance(): LogicalRDD.this.type = { + val rewrite = output.zip(output.map(_.newInstance())).toMap + + val rewrittenPartitioning = outputPartitioning match { + case p: Expression => + p.transform { + case e: Attribute => rewrite.getOrElse(e, e) + }.asInstanceOf[Partitioning] + + case p => p + } + + val rewrittenOrdering = outputOrdering.map(_.transform { + case e: Attribute => rewrite.getOrElse(e, e) + }.asInstanceOf[SortOrder]) + + LogicalRDD( + output.map(rewrite), + rdd, + rewrittenPartitioning, + rewrittenOrdering + )(session).asInstanceOf[this.type] + } override def sameResult(plan: LogicalPlan): Boolean = { plan.canonicalized match { - case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id + case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id case _ => false } } @@ -158,7 +181,9 @@ case class LogicalRDD( case class RDDScanExec( output: Seq[Attribute], rdd: RDD[InternalRow], - override val nodeName: String) extends LeafExecNode { + override val nodeName: String, + override val outputPartitioning: Partitioning = UnknownPartitioning(0), + override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7cfae5ce283bfa7e589fd0310b4b9ac44b62f33f..5412aca95dcf1974a3c118817c31ba3c041c4903 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.StreamingQuery /** * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting @@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil - case r : logical.Range => + case r: logical.Range => execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil - case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil + case r: LogicalRDD => + RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc367acae2ba4359a8af7690186979c8841d7f36..55f04878052aab114ded96848d27608a01a6a411 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.withColumn("b", expr("0")).as[ClassData] .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) } + + Seq(true, false).foreach { eager => + def testCheckpointing(testName: String)(f: => Unit): Unit = { + test(s"Dataset.checkpoint() - $testName (eager = $eager)") { + withTempDir { dir => + val originalCheckpointDir = spark.sparkContext.checkpointDir + + try { + spark.sparkContext.setCheckpointDir(dir.getCanonicalPath) + f + } finally { + // Since the original checkpointDir can be None, we need + // to set the variable directly. + spark.sparkContext.checkpointDir = originalCheckpointDir + } + } + } + } + + testCheckpointing("basic") { + val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc) + val cp = ds.checkpoint(eager) + + val logicalRDD = cp.logicalPlan match { + case plan: LogicalRDD => plan + case _ => + val treeString = cp.logicalPlan.treeString(verbose = true) + fail(s"Expecting a LogicalRDD, but got\n$treeString") + } + + val dsPhysicalPlan = ds.queryExecution.executedPlan + val cpPhysicalPlan = cp.queryExecution.executedPlan + + assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning } + assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering } + + assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning } + assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering } + + // For a lazy checkpoint() call, the first check also materializes the checkpoint. + checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + + // Reads back from checkpointed data and check again. + checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*) + } + + testCheckpointing("should preserve partitioning information") { + val ds = spark.range(10).repartition('id % 2) + val cp = ds.checkpoint(eager) + + val agg = cp.groupBy('id % 2).agg(count('id)) + + agg.queryExecution.executedPlan.collectFirst { + case ShuffleExchange(_, _: RDDScanExec, _) => + case BroadcastExchangeExec(_, _: RDDScanExec) => + }.foreach { _ => + fail( + "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " + + "preserves partitioning information:\n\n" + agg.queryExecution + ) + } + + checkAnswer(agg, ds.groupBy('id % 2).agg(count('id))) + } + } } case class Generic[T](id: T, value: Double)