Commit 363a476c authored by Michael Armbrust

[SPARK-11528] [SQL] Typed aggregations for Datasets

This PR adds the ability to do typed SQL aggregations.  We will likely also want to provide an interface to allow users to do aggregations on objects, but this is deferred to another PR.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds.groupBy(_._1).agg(sum("_2").as[Int]).collect()

res0: Array(("a", 30), ("b", 3), ("c", 1))
```
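
The same pattern extends to multiple aggregate columns and to the `count()` shorthand added below (a sketch against the API introduced in this patch):

```scala
// Multiple typed aggregations yield a Dataset of wider tuples,
// here Dataset[(String, Int, Long)].
ds.groupBy(_._1).agg(sum("_2").as[Int], count("*").as[Long]).collect()

// count() is shorthand for agg(count("*").as[Long]).
ds.groupBy(_._1).count().collect()
```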

Author: Michael Armbrust <michael@databricks.com>

Closes #9499 from marmbrus/dataset-agg.
parent eec74ba8
@@ -254,6 +254,10 @@ case class AttributeReference(
   }

   override def toString: String = s"$name#${exprId.id}$typeSuffix"
+
+  // Since the expression id is not in the first constructor it is missing from the default
+  // tree string.
+  override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
 }

 /**
...
@@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType
  * @since 1.6.0
  */
 @Experimental
-class Dataset[T] private(
+class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
     unresolvedEncoder: Encoder[T]) extends Serializable {
...
@@ -17,16 +17,25 @@
 package org.apache.spark.sql

+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution

 /**
+ * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
  * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
  * [[Dataset]].
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset]] extend `GroupedData`. However,
+ * making this change to the class hierarchy would break some function signatures. As such, this
+ * class should be considered a preview of the final API. Changes will be made to the interface
+ * after Spark 1.6.
  */
+@Experimental
 class GroupedDataset[K, T] private[sql](
     private val kEncoder: Encoder[K],
     private val tEncoder: Encoder[T],
@@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql](
     private val groupingAttributes: Seq[Attribute]) extends Serializable {

   private implicit val kEnc = kEncoder match {
-    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
     case other =>
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }
@@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql](
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }

+  /** Encoders for built-in aggregations. */
+  private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext

+  private def groupedData =
+    new GroupedData(
+      new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
    * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
@@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(f, groupingAttributes, logicalPlan))
   }

+  // To ensure valid overloading.
+  protected def agg(expr: Column, exprs: Column*): DataFrame =
+    groupedData.agg(expr, exprs: _*)
+
+  /**
+   * Internal helper function for building typed aggregations that return tuples. For simplicity
+   * and code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user facing interface.
+   * TODO: does not handle aggregations that return non-flat results.
+   */
+  protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+    val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
+      case u: UnresolvedAttribute => UnresolvedAlias(u)
+      case expr: NamedExpression => expr
+      case expr: Expression => Alias(expr, expr.prettyString)()
+    }
+
+    val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
+    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+    val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
+
+    // Rebind the encoders to the nested schema that will be produced by the aggregation.
+    val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
+      case (e: ExpressionEncoder[_], a) if !e.flat =>
+        e.nested(a).resolve(execution.analyzed.output)
+      case (e, a) =>
+        e.unbind(a :: Nil).resolve(execution.analyzed.output)
+    }
+
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+  }
+
+  /**
+   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing this aggregation over all elements in the group.
+   */
+  def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
+    aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
+    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3, A4](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3],
+      col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
+
+  /**
+   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
+   * for that key.
+   */
+  def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+
   /**
    * Applies the given function to each cogrouped data. For each unique group, the function will
    * be passed the grouping key and 2 iterators containing all elements in the group from
...
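The doc comment on `aggUntyped` above describes the design: one erased implementation builds the result without help from the type system, and arity-specific overloads cast back to the statically known tuple type. A minimal, self-contained sketch of that pattern in plain Scala (hypothetical names, not Spark internals):

```scala
object TypedWrapperSketch {
  // Erased core: builds a result without static types,
  // analogous to aggUntyped returning Dataset[_].
  private def coreAgg(cols: Seq[Any]): Product = cols match {
    case Seq(a)    => Tuple1(a)
    case Seq(a, b) => (a, b)
    case other     => sys.error(s"unsupported arity: ${other.size}")
  }

  // Typed wrappers recover the static type with a cast, just as
  // GroupedDataset.agg does with asInstanceOf[Dataset[(K, A1, ...)]].
  def agg1[A1](c1: A1): Tuple1[A1] =
    coreAgg(Seq(c1)).asInstanceOf[Tuple1[A1]]

  def agg2[A1, A2](c1: A1, c2: A2): (A1, A2) =
    coreAgg(Seq(c1, c2)).asInstanceOf[(A1, A2)]
}
```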
@@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
   }

+  test("typed aggregation: expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int]),
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("typed aggregation: expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
+      ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+  }
+
+  test("typed aggregation: expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
+      ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        sum("_2").as[Int],
+        sum($"_2" + 1).as[Long],
+        count("*").as[Long],
+        avg("_2").as[Double]),
+      ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
...
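For contrast, a hedged sketch of what these tests exercise at the type level (assuming `import sqlContext.implicits._` and `org.apache.spark.sql.functions._`, as in the suite): the typed overloads preserve the tuple element types, while the `Column`-based overload kept `protected` above returns an untyped DataFrame.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

// Typed: the compiler tracks the full result type.
val typed: Dataset[(String, Int, Long)] =
  ds.groupBy(_._1).agg(sum("_2").as[Int], count("*").as[Long])

// Untyped: GroupedData.agg on the equivalent DataFrame yields only Rows.
val untyped: DataFrame = ds.toDF().groupBy($"_1").agg(sum($"_2"))
```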