Commit 4262fb0d, authored by Herman van Hovell, committed by Wenchen Fan

[SPARK-19070] Clean-up dataset actions

## What changes were proposed in this pull request?
Dataset actions currently spin off a new `DataFrame` solely to track query execution. This PR simplifies that code path by using the Dataset's own `queryExecution` directly. It also merges the typed and untyped action evaluation paths.
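Condensed from the diff below, the merged code path looks roughly like this (a sketch, not the full method bodies): every action hands its own `QueryExecution` to a single `withAction` helper, which resets metrics, wraps the evaluation in one execution id, and reports to the listener manager.

```scala
// Sketch of the unified action path, condensed from the diff below.
private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
  try {
    qe.executedPlan.foreach(_.resetMetrics())
    val start = System.nanoTime()
    // One execution id now covers typed and untyped actions alike.
    val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
      action(qe.executedPlan)
    }
    sparkSession.listenerManager.onSuccess(name, qe, System.nanoTime() - start)
    result
  } catch {
    case e: Exception =>
      sparkSession.listenerManager.onFailure(name, qe, e)
      throw e
  }
}

// Actions pass their own query execution instead of spinning off a new DataFrame:
def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
  plan.executeCollect().head.getLong(0)
}
```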

## How was this patch tested?
Existing tests.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #16466 from hvanhovell/SPARK-19070.
parent a1e40b1f
```diff
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.python.EvaluatePython
@@ -2096,9 +2096,7 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df =>
-    df.collect(needCallback = false)
-  }
+  def head(n: Int): Array[T] = withAction("head", limit(n).queryExecution)(collectFromPlan)
 
   /**
    * Returns the first row.
@@ -2325,7 +2323,7 @@ class Dataset[T] private[sql](
   def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*)
 
   /**
-   * Returns an array that contains all of [[Row]]s in this Dataset.
+   * Returns an array that contains all rows in this Dataset.
    *
    * Running collect requires moving all the data into the application's driver process, and
    * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
@@ -2335,10 +2333,10 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def collect(): Array[T] = collect(needCallback = true)
+  def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
 
   /**
-   * Returns a Java list that contains all of [[Row]]s in this Dataset.
+   * Returns a Java list that contains all rows in this Dataset.
    *
    * Running collect requires moving all the data into the application's driver process, and
    * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
@@ -2346,27 +2344,13 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
-    withNewExecutionId {
-      val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
-      java.util.Arrays.asList(values : _*)
-    }
-  }
-
-  private def collect(needCallback: Boolean): Array[T] = {
-    def execute(): Array[T] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
-    }
-
-    if (needCallback) {
-      withCallback("collect", toDF())(_ => execute())
-    } else {
-      execute()
-    }
+  def collectAsList(): java.util.List[T] = withAction("collectAsList", queryExecution) { plan =>
+    val values = collectFromPlan(plan)
+    java.util.Arrays.asList(values : _*)
   }
 
   /**
-   * Return an iterator that contains all of [[Row]]s in this Dataset.
+   * Return an iterator that contains all rows in this Dataset.
    *
    * The iterator will consume as much memory as the largest partition in this Dataset.
    *
@@ -2377,9 +2361,9 @@ class Dataset[T] private[sql](
    * @group action
    * @since 2.0.0
    */
-  def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
-    withNewExecutionId {
-      queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava
+  def toLocalIterator(): java.util.Iterator[T] = {
+    withAction("toLocalIterator", queryExecution) { plan =>
+      plan.executeToIterator().map(boundEnc.fromRow).asJava
     }
   }
 
@@ -2388,8 +2372,8 @@ class Dataset[T] private[sql](
    * @group action
    * @since 1.6.0
    */
-  def count(): Long = withCallback("count", groupBy().count()) { df =>
-    df.collect(needCallback = false).head.getLong(0)
+  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
+    plan.executeCollect().head.getLong(0)
   }
 
   /**
@@ -2762,38 +2746,30 @@ class Dataset[T] private[sql](
    * Wrap a Dataset action to track the QueryExecution and time cost, then report to the
    * user-registered callback functions.
    */
-  private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = {
+  private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
     try {
-      df.queryExecution.executedPlan.foreach { plan =>
+      qe.executedPlan.foreach { plan =>
         plan.resetMetrics()
       }
       val start = System.nanoTime()
-      val result = action(df)
+      val result = SQLExecution.withNewExecutionId(sparkSession, qe) {
+        action(qe.executedPlan)
+      }
       val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, df.queryExecution, end - start)
+      sparkSession.listenerManager.onSuccess(name, qe, end - start)
       result
     } catch {
      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, df.queryExecution, e)
+        sparkSession.listenerManager.onFailure(name, qe, e)
        throw e
     }
   }
 
-  private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = {
-    try {
-      ds.queryExecution.executedPlan.foreach { plan =>
-        plan.resetMetrics()
-      }
-      val start = System.nanoTime()
-      val result = action(ds)
-      val end = System.nanoTime()
-      sparkSession.listenerManager.onSuccess(name, ds.queryExecution, end - start)
-      result
-    } catch {
-      case e: Exception =>
-        sparkSession.listenerManager.onFailure(name, ds.queryExecution, e)
-        throw e
-    }
-  }
+  /**
+   * Collect all elements from a spark plan.
+   */
+  private def collectFromPlan(plan: SparkPlan): Array[T] = {
+    plan.executeCollect().map(boundEnc.fromRow)
+  }
 
   private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
```
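The names passed to `withAction` ("head", "collect", "collectAsList", "toLocalIterator", "count") are what user-registered `QueryExecutionListener` callbacks receive; after this change the typed actions report through the same path as the untyped ones. A minimal sketch of that observable effect, assuming the Spark 2.x listener API (`TimingListener` is a hypothetical name):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical listener: logs the action name and duration reported by withAction.
class TimingListener extends QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    println(s"$funcName finished in ${durationNs / 1e6} ms")
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
spark.listenerManager.register(new TimingListener)

val ds = spark.range(100)
ds.count()    // listener sees funcName = "count"
ds.collect()  // listener sees funcName = "collect"
ds.head(5)    // listener sees funcName = "head"
```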