Commit 5d96a710 authored by gatorsmile's avatar gatorsmile Committed by Michael Armbrust

[SPARK-12188][SQL] Code refactoring and comment correction in Dataset APIs

This PR contains the following updates:

- Created a new private variable `boundTEncoder` that can be shared by multiple functions (`rdd`, `select`, and `collect`); a simplified sketch of this pattern follows this list.
- Replaced all occurrences of `queryExecution.analyzed` with the function call `logicalPlan`.
- Fixed a few API comments that used wrong class names (e.g., `DataFrame`) or parameter names (e.g., `n`).
- Corrected a few wrong API descriptions (e.g., `mapPartitions`).
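The `boundTEncoder` item is the usual hoisting refactor: a value that several methods re-derived on every call is computed once and shared. A minimal sketch of the pattern in plain Scala, with hypothetical names (`expensiveBind`, `Before`, `After`) standing in for the encoder machinery:

```scala
// Hypothetical illustration of the hoisting pattern behind `boundTEncoder`.
// `expensiveBind` stands in for `resolvedTEncoder.bind(logicalPlan.output)`.
object BoundEncoderSketch {
  private def expensiveBind(schema: Seq[String]): String => Int =
    s => s.length + schema.size       // pretend this is costly to build

  class Before(schema: Seq[String]) {
    // Each method re-derives the bound converter on every call.
    def rddLike(rows: Seq[String]): Seq[Int] = rows.map(expensiveBind(schema))
    def collectLike(rows: Seq[String]): Seq[Int] = rows.map(expensiveBind(schema))
  }

  class After(schema: Seq[String]) {
    // Bound once at construction and shared, as `boundTEncoder` now is.
    private val bound: String => Int = expensiveBind(schema)
    def rddLike(rows: Seq[String]): Seq[Int] = rows.map(bound)
    def collectLike(rows: Seq[String]): Seq[Int] = rows.map(bound)
  }
}
```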

marmbrus rxin cloud-fan Could you take a look and check if they are appropriate? Thank you!

Author: gatorsmile <gatorsmile@gmail.com>

Closes #10184 from gatorsmile/datasetClean.
parent c0b13d55
@@ -67,15 +67,21 @@ class Dataset[T] private[sql](
     tEncoder: Encoder[T]) extends Queryable with Serializable {
 
   /**
-   * An unresolved version of the internal encoder for the type of this dataset. This one is marked
-   * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
-   * object type (that will be possibly resolved to a different schema).
+   * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
+   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+   * same object type (that will be possibly resolved to a different schema).
    */
   private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(queryExecution.analyzed.output, OuterScopes.outerScopes)
+    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+  /**
+   * The encoder where the expressions used to construct an object from an input row have been
+   * bound to the ordinals of the given schema.
+   */
+  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
   private implicit def classTag = resolvedTEncoder.clsTag
@@ -89,7 +95,7 @@ class Dataset[T] private[sql](
   override def schema: StructType = resolvedTEncoder.schema
 
   /**
-   * Prints the schema of the underlying [[DataFrame]] to the console in a nice tree format.
+   * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
    * @since 1.6.0
    */
   override def printSchema(): Unit = toDF().printSchema()
@@ -111,7 +117,7 @@ class Dataset[T] private[sql](
    * ************* */
 
   /**
-   * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+   * Returns a new [[Dataset]] where each record has been mapped on to the specified type. The
    * method used to map columns depend on the type of `U`:
    *  - When `U` is a class, fields for the class will be mapped to columns of the same name
    *    (case sensitivity is determined by `spark.sql.caseSensitive`)
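For readers skimming the diff, a hedged usage sketch of the `as[U]` mapping rules, assuming a Spark 1.6-era `sqlContext` with its implicits in scope and a made-up `people.json`:

```scala
import org.apache.spark.sql.Dataset
import sqlContext.implicits._                  // assumes a Spark 1.6-era SQLContext in scope

case class Person(name: String, age: Long)

val df = sqlContext.read.json("people.json")   // hypothetical file with columns: name, age

// U is a class: DataFrame columns are matched to Person's fields by name.
val people: Dataset[Person] = df.as[Person]

// U is a primitive type: the first column of the DataFrame is used.
val names: Dataset[String] = df.select($"name").as[String]
```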
@@ -145,7 +151,7 @@ class Dataset[T] private[sql](
   def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
 
   /**
-   * Returns this Dataset.
+   * Returns this [[Dataset]].
    * @since 1.6.0
    */
   // This is declared with parentheses to prevent the Scala compiler from treating
@@ -153,15 +159,12 @@ class Dataset[T] private[sql](
   def toDS(): Dataset[T] = this
 
   /**
-   * Converts this Dataset to an RDD.
+   * Converts this [[Dataset]] to an [[RDD]].
    * @since 1.6.0
    */
   def rdd: RDD[T] = {
-    val tEnc = resolvedTEncoder
-    val input = queryExecution.analyzed.output
     queryExecution.toRdd.mapPartitions { iter =>
-      val bound = tEnc.bind(input)
-      iter.map(bound.fromRow)
+      iter.map(boundTEncoder.fromRow)
     }
   }
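The new `rdd` body decodes each partition with the shared `boundTEncoder` instead of re-binding per call; behavior is unchanged. A minimal usage sketch, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq("a", "bb", "ccc").toDS()

// Each partition's internal rows are converted back to objects of type T.
val lengths = ds.rdd.map(_.length).collect()   // Array(1, 2, 3)
```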
@@ -189,7 +192,7 @@ class Dataset[T] private[sql](
   def show(numRows: Int): Unit = show(numRows, truncate = true)
 
   /**
-   * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters
+   * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
    * will be truncated, and all cells will be aligned right.
    *
    * @since 1.6.0
@@ -197,7 +200,7 @@ class Dataset[T] private[sql](
   def show(): Unit = show(20)
 
   /**
-   * Displays the top 20 rows of [[DataFrame]] in a tabular form.
+   * Displays the top 20 rows of [[Dataset]] in a tabular form.
    *
    * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
    *                 be truncated and all cells will be aligned right
@@ -207,7 +210,7 @@ class Dataset[T] private[sql](
   def show(truncate: Boolean): Unit = show(20, truncate)
 
   /**
-   * Displays the [[DataFrame]] in a tabular form. For example:
+   * Displays the [[Dataset]] in a tabular form. For example:
    * {{{
    *   year  month AVG('Adj Close) MAX('Adj Close)
    *   1980  12    0.503218        0.595103
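A quick sketch of the three `show` variants documented above, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq("short", "a string well over twenty characters long").toDS()

ds.show()                   // top 20 rows, cells truncated to 20 characters
ds.show(5)                  // top 5 rows, truncated
ds.show(truncate = false)   // top 20 rows, full cell contents
```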
@@ -291,7 +294,7 @@ class Dataset[T] private[sql](
   /**
    * (Scala-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
    * @since 1.6.0
    */
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
@@ -307,7 +310,7 @@ class Dataset[T] private[sql](
   /**
    * (Java-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
    * @since 1.6.0
    */
   def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
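The corrected wording matters in practice: `func` receives a whole partition's iterator, so per-partition setup is paid once rather than per element. A hedged sketch, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq(1, 2, 3, 4).toDS()

val formatted = ds.mapPartitions { iter =>
  val prefix = "n="             // stand-in for expensive per-partition setup
  iter.map(n => prefix + n)     // applied to every element of the partition
}
formatted.show()
```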
@@ -341,28 +344,28 @@ class Dataset[T] private[sql](
   /**
    * (Scala-specific)
-   * Runs `func` on each element of this Dataset.
+   * Runs `func` on each element of this [[Dataset]].
    * @since 1.6.0
    */
   def foreach(func: T => Unit): Unit = rdd.foreach(func)
 
   /**
    * (Java-specific)
-   * Runs `func` on each element of this Dataset.
+   * Runs `func` on each element of this [[Dataset]].
    * @since 1.6.0
    */
   def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
 
   /**
    * (Scala-specific)
-   * Runs `func` on each partition of this Dataset.
+   * Runs `func` on each partition of this [[Dataset]].
    * @since 1.6.0
    */
   def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
 
   /**
    * (Java-specific)
-   * Runs `func` on each partition of this Dataset.
+   * Runs `func` on each partition of this [[Dataset]].
    * @since 1.6.0
    */
   def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
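Usage sketch for the two traversal flavors, assuming a 1.6-era `sqlContext`; note that `println` output lands in executor logs, not on the driver:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq(1, 2, 3).toDS()

// Per-element side effect, run on the executors.
ds.foreach(n => println(n))

// Per-partition variant: acquire and release shared resources once per partition.
ds.foreachPartition { iter =>
  iter.foreach(n => println(n))   // stand-in for writing to an external sink
}
```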
@@ -374,7 +377,7 @@ class Dataset[T] private[sql](
   /**
    * (Scala-specific)
-   * Reduces the elements of this Dataset using the specified binary function. The given function
+   * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
@@ -382,7 +385,7 @@ class Dataset[T] private[sql](
   /**
    * (Java-specific)
-   * Reduces the elements of this Dataset using the specified binary function. The given function
+   * Reduces the elements of this Dataset using the specified binary function. The given `func`
    * must be commutative and associative or the result may be non-deterministic.
    * @since 1.6.0
    */
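A small sketch of the commutativity/associativity caveat, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq(1, 2, 3, 4).toDS()

// Addition is commutative and associative, so the result is always 10.
val sum = ds.reduce(_ + _)

// Subtraction is neither, so the result below depends on partitioning;
// this is exactly the non-determinism the doc comment warns about.
val risky = ds.reduce(_ - _)
```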
@@ -390,11 +393,11 @@ class Dataset[T] private[sql](
   /**
    * (Scala-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
    * @since 1.6.0
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
-    val inputPlan = queryExecution.analyzed
+    val inputPlan = logicalPlan
    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
    val executed = sqlContext.executePlan(withGroupingKey)
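Usage sketch for the key-function `groupBy`, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val words = Seq("spark", "scala", "sql").toDS()

// Group by a computed key (the first letter); yields GroupedDataset[String, String].
val byInitial = words.groupBy(w => w.substring(0, 1))

// e.g. count the words per key with mapGroups.
byInitial.mapGroups { (initial, ws) => (initial, ws.size) }.show()
```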
@@ -429,18 +432,18 @@ class Dataset[T] private[sql](
   /**
    * (Java-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key function.
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
    * @since 1.6.0
    */
-  def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
-    groupBy(f.call(_))(encoder)
+  def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupBy(func.call(_))(encoder)
 
   /* ****************** *
    *  Typed Relational  *
    * ****************** */
 
   /**
-   * Selects a set of column based expressions.
+   * Returns a new [[DataFrame]] by selecting a set of column based expressions.
    * {{{
    *   df.select($"colA", $"colB" + 1)
    * }}}
@@ -464,8 +467,8 @@ class Dataset[T] private[sql](
       sqlContext,
       Project(
         c1.withInputType(
-          resolvedTEncoder.bind(queryExecution.analyzed.output),
-          queryExecution.analyzed.output).named :: Nil,
+          boundTEncoder,
+          logicalPlan.output).named :: Nil,
         logicalPlan))
   }
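This is the code path exercised by the single-column typed `select`, which now reuses the shared `boundTEncoder` instead of re-binding. A hedged usage sketch, assuming a 1.6-era `sqlContext`:

```scala
import org.apache.spark.sql.functions.expr
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = Seq(1, 2, 3).toDS()    // single column, named "value" by the encoder

// expr(...).as[U] produces the TypedColumn that select(...) binds above.
val doubled = ds.select(expr("value * 2").as[Int])
doubled.collect()               // Array(2, 4, 6)
```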
@@ -477,7 +480,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
+      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
@@ -654,7 +657,7 @@ class Dataset[T] private[sql](
    * Returns an array that contains all the elements in this [[Dataset]].
    *
    * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
    *
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
@@ -662,17 +665,14 @@ class Dataset[T] private[sql](
   def collect(): Array[T] = {
     // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
     // to convert the rows into objects of type T.
-    val tEnc = resolvedTEncoder
-    val input = queryExecution.analyzed.output
-    val bound = tEnc.bind(input)
-    queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
+    queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
   }
 
   /**
    * Returns an array that contains all the elements in this [[Dataset]].
    *
    * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large dataset can crash the driver process with OutOfMemoryError.
+   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
    *
    * For Java API, use [[collectAsList]].
    * @since 1.6.0
@@ -683,7 +683,7 @@ class Dataset[T] private[sql](
    * Returns the first `num` elements of this [[Dataset]] as an array.
    *
    * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `n` can crash the driver process with OutOfMemoryError.
+   * a very large `num` can crash the driver process with OutOfMemoryError.
    * @since 1.6.0
    */
   def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
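To make the `num` caveat concrete: `take` moves only `num` elements to the driver, while `collect` (above) moves everything. A sketch, assuming a 1.6-era `sqlContext`:

```scala
import sqlContext.implicits._   // assumes a Spark 1.6-era SQLContext in scope

val ds = sqlContext.range(0, 1000000).as[Long]

val firstFive = ds.take(5)      // only 5 elements reach the driver

// ds.collect() would pull all one million elements into the driver process;
// with a very large Dataset this is what risks the OutOfMemoryError.
```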
@@ -692,7 +692,7 @@ class Dataset[T] private[sql](
    * Returns the first `num` elements of this [[Dataset]] as an array.
    *
    * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `n` can crash the driver process with OutOfMemoryError.
+   * a very large `num` can crash the driver process with OutOfMemoryError.
    * @since 1.6.0
    */
   def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)