From 968acf3bd9a502fcad15df3e53e359695ae702cc Mon Sep 17 00:00:00 2001 From: Michael Armbrust <michael@databricks.com> Date: Fri, 20 Nov 2015 15:36:30 -0800 Subject: [PATCH] [SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` <console>:38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust <michael@databricks.com> Closes #9870 from marmbrus/dataset-repl-agg. --- .../org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++ .../org/apache/spark/sql/GroupedDataset.scala | 27 +++---------------- .../apache/spark/sql/JavaDatasetSuite.java | 8 +++--- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 081aa03002..cbcccb11f1 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite { } } + test("Datasets agg type-inference") { + val output = runInterpreter("local", + """ + |import org.apache.spark.sql.functions._ + |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.expressions.Aggregator + |import org.apache.spark.sql.TypedColumn + |/** An `Aggregator` that adds up any numeric type returned by the given function. */ + |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + | val numeric = implicitly[Numeric[N]] + | override def zero: N = numeric.zero + | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) + | override def finish(reduction: N): N = reduction + |} + | + |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn + |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() + |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } + test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 6de3dd6265..263f049104 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql]( reduce(f.call _) } - /** - * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]]. - * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again. - * - * The available aggregate methods are defined in [[org.apache.spark.sql.functions]]. - * - * {{{ - * // Selects the age of the oldest employee and the aggregate expense for each department - * - * // Scala: - * import org.apache.spark.sql.functions._ - * df.groupBy("department").agg(max("age"), sum("expense")) - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df.groupBy("department").agg(max("age"), sum("expense")); - * }}} - * - * We can also use `Aggregator.toColumn` to pass in typed aggregate functions. - * - * @since 1.6.0 - */ + // This is here to prevent us from adding overloads that would be ambiguous. @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DataFrame = - groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*) + private def agg(exprs: Column*): DataFrame = + groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*) private def withEncoder(c: Column): Column = c match { case tc: TypedColumn[_, _] => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index ce40dd856f..f7249b8945 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -404,11 +404,9 @@ public class JavaDatasetSuite implements Serializable { grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()), - expr("sum(_2)"), - count("*")) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG())); + Dataset<Tuple2<String, Integer>> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( new Tuple4<>("a", 3, 3L, 2L), -- GitLab