From 8bfc3b7aac577e36aadc4fe6dee0665d0b2ae919 Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Mon, 31 Oct 2016 13:39:59 -0700
Subject: [PATCH] [SPARK-17972][SQL] Add Dataset.checkpoint() to truncate large query plans

## What changes were proposed in this pull request?

### Problem

Iterative ML code may easily create query plans that grow exponentially. We found that query planning time also increases exponentially even when all the sub-plan trees are cached. The following snippet illustrates the problem:

``` scala
(0 until 6).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")
  val time0 = System.currentTimeMillis()
  val joined = plan.join(plan, "value").join(plan, "value").join(plan, "value").join(plan, "value")
  joined.cache()
  println(s"Query planning takes ${System.currentTimeMillis() - time0} ms")
  joined.as[Int]
}

// == Iteration 0 ==
// Query planning takes 9 ms
// == Iteration 1 ==
// Query planning takes 26 ms
// == Iteration 2 ==
// Query planning takes 53 ms
// == Iteration 3 ==
// Query planning takes 163 ms
// == Iteration 4 ==
// Query planning takes 700 ms
// == Iteration 5 ==
// Query planning takes 3418 ms
```

This is because when building a new Dataset, the new plan is always built upon `QueryExecution.analyzed`, which doesn't leverage existing cached plans. On the other hand, caching every few iterations is usually not the right direction for this problem, since caching consumes too much memory (imagine computing connected components over a graph with 50 billion nodes). What we really need here is to truncate both the query plan (to minimize query planning time) and the lineage of the underlying RDD (to avoid stack overflow).

### Changes introduced in this PR

This PR fixes the issue by introducing a `checkpoint()` method on `Dataset[T]`, which does exactly the two things described above. The snippet in the "Micro benchmark" section below, which is essentially the same as the one above but invokes `checkpoint()` instead of `cache()`, shows the resulting behavior.

One key point is that the checkpointed Dataset should preserve the partitioning and ordering information of the original Dataset, so that we can avoid unnecessary shuffling (similar to reading from a pre-bucketed table). This is done by adding `outputPartitioning` and `outputOrdering` to `LogicalRDD` and `RDDScanExec`.
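
As a rough illustration of what this buys us (mirroring the "should preserve partitioning information" test added to `DatasetSuite` in this patch), grouping a checkpointed Dataset by the same key it was repartitioned on should not require another `Exchange`. The snippet below is only a sketch: it assumes a `spark` session (as in `spark-shell`) with a checkpoint directory already configured, and the exact physical plan may vary with configuration.

``` scala
import spark.implicits._
import org.apache.spark.sql.functions.count

// Sketch only: assumes spark.sparkContext.setCheckpointDir(...) has already been called.
val ds = spark.range(10).repartition('id % 2)  // hash partitioned by id % 2
val cp = ds.checkpoint()                       // truncates the plan and the RDD lineage

// Since LogicalRDD now carries the checkpointed Dataset's outputPartitioning,
// this aggregation should be able to reuse the existing hash partitioning
// instead of adding another Exchange above the RDDScanExec node.
cp.groupBy('id % 2).agg(count('id)).explain()
```
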
### Micro benchmark

``` scala
spark.sparkContext.setCheckpointDir("/tmp/cp")

(0 until 100).foldLeft(Seq(1, 2, 3).toDS) { (plan, iteration) =>
  println(s"== Iteration $iteration ==")

  val time0 = System.currentTimeMillis()
  val cp = plan.checkpoint()
  cp.count()
  System.out.println(s"Checkpointing takes ${System.currentTimeMillis() - time0} ms")

  val time1 = System.currentTimeMillis()
  val joined = cp.join(cp, "value").join(cp, "value").join(cp, "value").join(cp, "value")
  val result = joined.as[Int]
  println(s"Query planning takes ${System.currentTimeMillis() - time1} ms")

  result
}

// == Iteration 0 ==
// Checkpointing takes 591 ms
// Query planning takes 13 ms
// == Iteration 1 ==
// Checkpointing takes 1605 ms
// Query planning takes 16 ms
// == Iteration 2 ==
// Checkpointing takes 782 ms
// Query planning takes 8 ms
// == Iteration 3 ==
// Checkpointing takes 729 ms
// Query planning takes 10 ms
// == Iteration 4 ==
// Checkpointing takes 734 ms
// Query planning takes 9 ms
// == Iteration 5 ==
// ...
// == Iteration 50 ==
// Checkpointing takes 571 ms
// Query planning takes 7 ms
// == Iteration 51 ==
// Checkpointing takes 548 ms
// Query planning takes 7 ms
// == Iteration 52 ==
// Checkpointing takes 596 ms
// Query planning takes 8 ms
// == Iteration 53 ==
// Checkpointing takes 568 ms
// Query planning takes 7 ms
// ...
```

You may see that although checkpointing is a more heavyweight operation, the combined cost of checkpointing and query planning stays roughly constant across iterations instead of growing.

### Open question

mengxr mentioned that it would be more convenient if we could make `Dataset.checkpoint()` eager, i.e., always perform an `RDD.count()` after calling `RDD.checkpoint()`. Not quite sure whether this is a universal requirement. Maybe we can add an `eager: Boolean` argument for `Dataset.checkpoint()` to support that.
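
For reference, with the `eager` flag as added in this patch (the no-arg overload defaults to eager), the two modes would be used roughly as sketched below; the variable names are illustrative only, and the lazy variant is materialized by the first action that runs a job on the checkpointed Dataset.

``` scala
import spark.implicits._

spark.sparkContext.setCheckpointDir("/tmp/cp")

val ds = Seq(1, 2, 3).toDS()

// Eager (default): runs a count() on the underlying RDD right away, so the
// checkpoint data is written before checkpoint() returns.
val eagerCp = ds.checkpoint()

// Lazy: only marks the underlying RDD for checkpointing; the checkpoint data
// is written once an action on the returned Dataset triggers a job.
val lazyCp = ds.checkpoint(eager = false)
lazyCp.count()  // materializes the checkpoint here
```
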
## How was this patch tested?

Unit test added in `DatasetSuite`.

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>

Closes #15651 from liancheng/ds-checkpoint.
---
 .../scala/org/apache/spark/sql/Dataset.scala  | 57 +++++++++++++++-
 .../spark/sql/execution/ExistingRDD.scala     | 37 ++++++++--
 .../spark/sql/execution/SparkStrategies.scala |  7 +-
 .../org/apache/spark/sql/DatasetSuite.scala   | 68 +++++++++++++++++++
 4 files changed, 157 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 286d8549bf..6e0a2471e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -40,13 +40,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.streaming.{DataStreamWriter, StreamingQuery}
+import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -482,6 +483,58 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def isStreaming: Boolean = logicalPlan.isStreaming
 
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(): Dataset[T] = checkpoint(eager = true)
+
+  /**
+   * Returns a checkpointed version of this Dataset.
+   *
+   * @param eager When true, materializes the underlying checkpointed RDD eagerly.
+   *
+   * @group basic
+   * @since 2.1.0
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def checkpoint(eager: Boolean): Dataset[T] = {
+    val internalRdd = queryExecution.toRdd.map(_.copy())
+    internalRdd.checkpoint()
+
+    if (eager) {
+      internalRdd.count()
+    }
+
+    val physicalPlan = queryExecution.executedPlan
+
+    // Takes the first leaf partitioning whenever we see a `PartitioningCollection`. Otherwise the
+    // size of `PartitioningCollection` may grow exponentially for queries involving deep inner
+    // joins.
+    def firstLeafPartitioning(partitioning: Partitioning): Partitioning = {
+      partitioning match {
+        case p: PartitioningCollection => firstLeafPartitioning(p.partitionings.head)
+        case p => p
+      }
+    }
+
+    val outputPartitioning = firstLeafPartitioning(physicalPlan.outputPartitioning)
+
+    Dataset.ofRows(
+      sparkSession,
+      LogicalRDD(
+        logicalPlan.output,
+        internalRdd,
+        outputPartitioning,
+        physicalPlan.outputOrdering
+      )(sparkSession)).as[T]
+  }
+
   /**
    * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated,
    * and all cells will be aligned right. For example:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d3a2222862..455fb5bfbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -130,17 +130,40 @@ case class ExternalRDDScanExec[T](
 /** Logical plan node for scanning data from an RDD of InternalRow. */
 case class LogicalRDD(
     output: Seq[Attribute],
-    rdd: RDD[InternalRow])(session: SparkSession)
+    rdd: RDD[InternalRow],
+    outputPartitioning: Partitioning = UnknownPartitioning(0),
+    outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
   extends LeafNode with MultiInstanceRelation {
 
   override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
 
-  override def newInstance(): LogicalRDD.this.type =
-    LogicalRDD(output.map(_.newInstance()), rdd)(session).asInstanceOf[this.type]
+  override def newInstance(): LogicalRDD.this.type = {
+    val rewrite = output.zip(output.map(_.newInstance())).toMap
+
+    val rewrittenPartitioning = outputPartitioning match {
+      case p: Expression =>
+        p.transform {
+          case e: Attribute => rewrite.getOrElse(e, e)
+        }.asInstanceOf[Partitioning]
+
+      case p => p
+    }
+
+    val rewrittenOrdering = outputOrdering.map(_.transform {
+      case e: Attribute => rewrite.getOrElse(e, e)
+    }.asInstanceOf[SortOrder])
+
+    LogicalRDD(
+      output.map(rewrite),
+      rdd,
+      rewrittenPartitioning,
+      rewrittenOrdering
+    )(session).asInstanceOf[this.type]
+  }
 
   override def sameResult(plan: LogicalPlan): Boolean = {
     plan.canonicalized match {
-      case LogicalRDD(_, otherRDD) => rdd.id == otherRDD.id
+      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
       case _ => false
     }
   }
@@ -158,7 +181,9 @@ case class LogicalRDD(
 case class RDDScanExec(
     output: Seq[Attribute],
     rdd: RDD[InternalRow],
-    override val nodeName: String) extends LeafExecNode {
+    override val nodeName: String,
+    override val outputPartitioning: Partitioning = UnknownPartitioning(0),
+    override val outputOrdering: Seq[SortOrder] = Nil) extends LeafExecNode {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7cfae5ce28..5412aca95d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -32,8 +32,6 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.execution.streaming.{MemoryPlan, StreamingExecutionRelation, StreamingRelation, StreamingRelationExec}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.StreamingQuery
 
 /**
  * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -402,13 +400,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
-      case r : logical.Range =>
+      case r: logical.Range =>
         execution.RangeExec(r) :: Nil
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
-      case LogicalRDD(output, rdd) => RDDScanExec(output, rdd, "ExistingRDD") :: Nil
+      case r: LogicalRDD =>
+        RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cc367acae2..55f0487805 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,8 +22,11 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.util.sideBySide
+import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SortExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
@@ -919,6 +922,71 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       df.withColumn("b", expr("0")).as[ClassData]
         .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
+
+  Seq(true, false).foreach { eager =>
+    def testCheckpointing(testName: String)(f: => Unit): Unit = {
+      test(s"Dataset.checkpoint() - $testName (eager = $eager)") {
+        withTempDir { dir =>
+          val originalCheckpointDir = spark.sparkContext.checkpointDir
+
+          try {
+            spark.sparkContext.setCheckpointDir(dir.getCanonicalPath)
+            f
+          } finally {
+            // Since the original checkpointDir can be None, we need
+            // to set the variable directly.
+            spark.sparkContext.checkpointDir = originalCheckpointDir
+          }
+        }
+      }
+    }
+
+    testCheckpointing("basic") {
+      val ds = spark.range(10).repartition('id % 2).filter('id > 5).orderBy('id.desc)
+      val cp = ds.checkpoint(eager)
+
+      val logicalRDD = cp.logicalPlan match {
+        case plan: LogicalRDD => plan
+        case _ =>
+          val treeString = cp.logicalPlan.treeString(verbose = true)
+          fail(s"Expecting a LogicalRDD, but got\n$treeString")
+      }
+
+      val dsPhysicalPlan = ds.queryExecution.executedPlan
+      val cpPhysicalPlan = cp.queryExecution.executedPlan
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { logicalRDD.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { logicalRDD.outputOrdering }
+
+      assertResult(dsPhysicalPlan.outputPartitioning) { cpPhysicalPlan.outputPartitioning }
+      assertResult(dsPhysicalPlan.outputOrdering) { cpPhysicalPlan.outputOrdering }
+
+      // For a lazy checkpoint() call, the first check also materializes the checkpoint.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+
+      // Reads back from checkpointed data and check again.
+      checkDataset(cp, (9L to 6L by -1L).map(java.lang.Long.valueOf): _*)
+    }
+
+    testCheckpointing("should preserve partitioning information") {
+      val ds = spark.range(10).repartition('id % 2)
+      val cp = ds.checkpoint(eager)
+
+      val agg = cp.groupBy('id % 2).agg(count('id))
+
+      agg.queryExecution.executedPlan.collectFirst {
+        case ShuffleExchange(_, _: RDDScanExec, _) =>
+        case BroadcastExchangeExec(_, _: RDDScanExec) =>
+      }.foreach { _ =>
+        fail(
+          "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " +
+            "preserves partitioning information:\n\n" + agg.queryExecution
+        )
+      }
+
+      checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
+    }
+  }
 }
 
 case class Generic[T](id: T, value: Double)
--
GitLab