diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8957df0be6814a8b4a3af08acc14e0f5509a0700..9ab5c299d0f557f1fd2f86e485b4a2c1e8ae5265 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -254,6 +254,10 @@ case class AttributeReference(
   }
 
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
+
+  // Since the expression id is not in the first constructor, it is missing from the default
+  // tree string.
+  override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 500227e93a472ab5a13fc42393547de6615b0af5..4bca9c3b3fe54fd809a190325ec0eaa6f975fa9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType
  * @since 1.6.0
  */
 @Experimental
-class Dataset[T] private(
+class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
     unresolvedEncoder: Encoder[T]) extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 96d6e9dd548e53b4b28a9cc89d3e46f0127c2762..b8fc373dffcf59e4898595cc61ec3cb2a276bb5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,16 +17,25 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 
 /**
+ * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
  * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
  * [[Dataset]].
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset]] extend `GroupedData`. However,
+ * making this change to the class hierarchy would break some function signatures. As such, this
+ * class should be considered a preview of the final API. Changes will be made to the interface
+ * after Spark 1.6.
  */
+@Experimental
 class GroupedDataset[K, T] private[sql](
     private val kEncoder: Encoder[K],
     private val tEncoder: Encoder[T],
@@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql](
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
   private implicit val kEnc = kEncoder match {
-    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
     case other =>
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }
@@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql](
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }
 
+  /** Encoders for built-in aggregations. */
+  private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
+  private def groupedData =
+    new GroupedData(
+      new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
    * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
@@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
+  // To ensure valid overloading.
+  protected def agg(expr: Column, exprs: Column*): DataFrame =
+    groupedData.agg(expr, exprs: _*)
+
+  /**
+   * Internal helper function for building typed aggregations that return tuples. For simplicity
+   * and code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user-facing interface.
+   * TODO: does not handle aggregations that return non-flat results.
+   */
+  protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+    val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
+      case u: UnresolvedAttribute => UnresolvedAlias(u)
+      case expr: NamedExpression => expr
+      case expr: Expression => Alias(expr, expr.prettyString)()
+    }
+
+    val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
+    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+    val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
+
+    // Rebind the encoders to the nested schema that will be produced by the aggregation.
+    val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
+      case (e: ExpressionEncoder[_], a) if !e.flat =>
+        e.nested(a).resolve(execution.analyzed.output)
+      case (e, a) =>
+        e.unbind(a :: Nil).resolve(execution.analyzed.output)
+    }
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+  }
+
+  /**
+   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing this aggregation over all elements in the group.
+   */
+  def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
+    aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
+    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3, A4](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3],
+      col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
+
+  /**
+   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
+   * for that key.
+   */
+  def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+
   /**
    * Applies the given function to each cogrouped data. For each unique group, the function will
    * be passed the grouping key and 2 iterators containing all elements in the group from
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3e9b621cfd67f8d5ab92dc104cf39c89c6012f53..d61e17edc64edaf6c65b5a03485343633f1fbe93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
   }
 
+  test("typed aggregation: expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int]),
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("typed aggregation: expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
+      ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+  }
+
+  test("typed aggregation: expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
+      ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        sum("_2").as[Int],
+        sum($"_2" + 1).as[Long],
+        count("*").as[Long],
+        avg("_2").as[Double]),
+      ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
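A brief usage sketch of the typed aggregation API this patch introduces, mirroring the new DatasetSuite tests above; the surrounding setup (a SQLContext named `sqlContext` and its `implicits._` import) is assumed for illustration and is not part of the patch:

  import org.apache.spark.sql.functions._
  import sqlContext.implicits._  // assumed: brings toDS() and the $ column interpolator into scope

  val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

  // groupBy with a key function yields a GroupedDataset; the new typed agg
  // returns a Dataset of tuples of the key followed by each aggregate value.
  val sums = ds.groupBy(_._1).agg(sum("_2").as[Int])                        // Dataset[(String, Int)]
  val stats = ds.groupBy(_._1).agg(sum("_2").as[Int], count("*").as[Long])  // Dataset[(String, Int, Long)]

  // count() is shorthand for agg(functions.count("*").as[Long]) and relies on the
  // built-in Long encoder added in this patch.
  val counts = ds.groupBy(_._1).count()                                     // Dataset[(String, Long)]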