diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f240c30427b45ec22012dc658b8326676e8d6955..290de794dc3bbf7903e278dec9fdaf9b18d51e32 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -329,6 +329,11 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
 
+        // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
+        ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
+
         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index c8b78bc14a39c2d77f5a536f41e3bd469be9d2b1..547da8f713ac77f97d8cc767495991bc881cdc55 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -285,7 +285,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
        |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
+       |import org.apache.spark.sql.{Encoder, Encoders}
        |import org.apache.spark.sql.expressions.Aggregator
        |import org.apache.spark.sql.TypedColumn
        |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -293,6 +293,8 @@ class ReplSuite extends SparkFunSuite {
        |  def reduce(b: Int, a: Int) = b + a // Add an element to the running total
        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
        |  def finish(b: Int) = b // Return the final result.
+       |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+       |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
        |}.toColumn
        |
        |val ds = Seq(1, 2, 3, 4).toDS()
@@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-       |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
-       |import org.apache.spark.sql.expressions.Aggregator
-       |import org.apache.spark.sql.TypedColumn
-       |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-       |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
-       |  val numeric = implicitly[Numeric[N]]
-       |  override def zero: N = numeric.zero
-       |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-       |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-       |  override def finish(reduction: N): N = reduction
-       |}
-       |
-       |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-       |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-       |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index dbfacba34637d27694d979a3155da041f7a03f94..7e10f15226b7abf77c73defc84f18ce36df411df 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite {
     val output = runInterpreter("local",
       """
        |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
+       |import org.apache.spark.sql.{Encoder, Encoders}
        |import org.apache.spark.sql.expressions.Aggregator
        |import org.apache.spark.sql.TypedColumn
        |val simpleSum = new Aggregator[Int, Int, Int] {
@@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite {
        |  def reduce(b: Int, a: Int) = b + a // Add an element to the running total
        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
        |  def finish(b: Int) = b // Return the final result.
+       |  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+       |  def outputEncoder: Encoder[Int] = Encoders.scalaInt
        |}.toColumn
        |
        |val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
-  test("Datasets agg type-inference") {
-    val output = runInterpreter("local",
-      """
-       |import org.apache.spark.sql.functions._
-       |import org.apache.spark.sql.Encoder
-       |import org.apache.spark.sql.expressions.Aggregator
-       |import org.apache.spark.sql.TypedColumn
-       |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-       |class SumOf[I, N : Numeric](f: I => N) extends
-       |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
-       |  val numeric = implicitly[Numeric[N]]
-       |  override def zero: N = numeric.zero
-       |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
-       |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
-       |  override def finish(reduction: N): N = reduction
-       |}
-       |
-       |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
-       |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
-       |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
-      """.stripMargin)
-    assertDoesNotContain("error:", output)
-    assertDoesNotContain("Exception", output)
-  }
-
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 7a18d0afce6bafdd1130a9cdac54a9428bca2aa3..c39a78da6f9be7bbab752f090f7b181e81c542d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.TypedColumn
+import org.apache.spark.sql.{Encoder, TypedColumn}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 
@@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 
-class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] {
-  val numeric = implicitly[Numeric[OUT]]
-  override def zero: OUT = numeric.zero
-  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
-  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
-  override def finish(reduction: OUT): OUT = reduction
-
-  // TODO(ekl) java api support once this is exposed in scala
-}
-
-
 class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
   override def zero: Double = 0.0
   override def reduce(b: Double, a: IN): Double = b + f(a)
   override def merge(b1: Double, b2: Double): Double = b1 + b2
   override def finish(reduction: Double): Double = reduction
 
+  override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
 
@@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] {
   override def merge(b1: Long, b2: Long): Long = b1 + b2
   override def finish(reduction: Long): Long = reduction
 
+  override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+  override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
   // Java api support
   def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
-  def toColumnJava(): TypedColumn[IN, java.lang.Long] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Long]]
+  def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
   }
 }
 
@@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D
     (b1._1 + b2._1, b1._2 + b2._2)
   }
 
+  override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]()
+  override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
   // Java api support
   def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
-  def toColumnJava(): TypedColumn[IN, java.lang.Double] = {
-    toColumn(ExpressionEncoder(), ExpressionEncoder())
-      .asInstanceOf[TypedColumn[IN, java.lang.Double]]
+  def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+    toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 9cb356f1ca375b146f6af3cf3bd4229975938b01..7da8379c9aa9aae740d5aa6b99498841553a3349 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  *
  * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird
  *
- * @tparam I The input type for the aggregation.
- * @tparam B The type of the intermediate value of the reduction.
- * @tparam O The type of the final output result.
+ * @tparam IN The input type for the aggregation.
+ * @tparam BUF The type of the intermediate value of the reduction.
+ * @tparam OUT The type of the final output result.
  * @since 1.6.0
  */
-abstract class Aggregator[-I, B, O] extends Serializable {
+abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
 
   /**
    * A zero value for this aggregation. Should satisfy the property that any b + zero = b.
    * @since 1.6.0
    */
-  def zero: B
+  def zero: BUF
 
   /**
    * Combine two values to produce a new value. For performance, the function may modify `b` and
    * return it instead of constructing new object for b.
    * @since 1.6.0
    */
-  def reduce(b: B, a: I): B
+  def reduce(b: BUF, a: IN): BUF
 
   /**
    * Merge two intermediate values.
    * @since 1.6.0
    */
-  def merge(b1: B, b2: B): B
+  def merge(b1: BUF, b2: BUF): BUF
 
   /**
    * Transform the output of the reduction.
    * @since 1.6.0
    */
-  def finish(reduction: B): O
+  def finish(reduction: BUF): OUT
 
   /**
-   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]]
+   * Specifies the [[Encoder]] for the intermediate value type.
+   * @since 2.0.0
+   */
+  def bufferEncoder: Encoder[BUF]
+
+  /**
+   * Specifies the [[Encoder]] for the final output value type.
+   * @since 2.0.0
+   */
+  def outputEncoder: Encoder[OUT]
+
+  /**
+   * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]]
    * operations.
    * @since 1.6.0
    */
-  def toColumn(
-      implicit bEncoder: Encoder[B],
-      cEncoder: Encoder[O]): TypedColumn[I, O] = {
+  def toColumn: TypedColumn[IN, OUT] = {
+    implicit val bEncoder = bufferEncoder
+    implicit val cEncoder = outputEncoder
+
     val expr = AggregateExpression(
       TypedAggregateExpression(this),
       Complete,
       isDistinct = false)
 
-    new TypedColumn[I, O](expr, encoderFor[O])
+    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
   }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
index 8cb174b9067f34f6b816e09a25a2d36cbddf0e82..0e49f871de5c4e4c06a33f5e9b7c4a20899a7148 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java
@@ -26,6 +26,7 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.KeyValueGroupedDataset;
 import org.apache.spark.sql.expressions.Aggregator;
@@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
   public void testTypedAggregationAnonClass() {
     KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped = generateGroupedDataset();
 
-    Dataset<Tuple2<String, Integer>> agged =
-      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Dataset<Tuple2<String, Integer>> agged = grouped.agg(new IntSumOf().toColumn());
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
 
-    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(new IntSumOf().toColumn())
       .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
@@ -73,6 +72,16 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase {
     public Integer finish(Integer reduction) {
       return reduction;
     }
+
+    @Override
+    public Encoder<Integer> bufferEncoder() {
+      return Encoders.INT();
+    }
+
+    @Override
+    public Encoder<Integer> outputEncoder() {
+      return Encoders.INT();
+    }
   }
 
   @Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 08b3389ad918abcf32a368f48c1e1b56ae53c60f..3a7215ee39728345ad8215eb86dfe21ce1b1d6a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -26,74 +27,65 @@ import org.apache.spark.sql.test.SharedSQLContext
 
 
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
-
   override def zero: (Long, Long) = (0, 0)
-
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
     (countAndSum._1 + 1, countAndSum._2 + input._2)
   }
-
   override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = {
     (b1._1 + b2._1, b1._2 + b2._2)
   }
-
   override def finish(reduction: (Long, Long)): (Long, Long) = reduction
+  override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
+  override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)]
 }
 
+
 case class AggData(a: Int, b: String)
+
 object ClassInputAgg extends Aggregator[AggData, Int, Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: Int = 0
-
-  /**
-   * Combine two values to produce a new value. For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: Int, a: AggData): Int = b + a.a
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: Int): Int = reduction
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: Int, b2: Int): Int = b1 + b2
+  override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] {
-  /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   override def zero: (Int, AggData) = 0 -> AggData(0, "0")
-
-  /**
-   * Combine two values to produce a new value. For performance, the function may modify `b` and
-   * return it instead of constructing new object for b.
-   */
   override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a)
-
-  /**
-   * Transform the output of the reduction.
-   */
   override def finish(reduction: (Int, AggData)): Int = reduction._1
-
-  /**
-   * Merge two intermediate values
-   */
   override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) =
     (b1._1 + b2._1, b1._2)
+  override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)]
+  override def outputEncoder: Encoder[Int] = Encoders.scalaInt
 }
 
+
 object NameAgg extends Aggregator[AggData, String, String] {
   def zero: String = ""
-
   def reduce(b: String, a: AggData): String = a.b + b
-
   def merge(b1: String, b2: String): String = b1 + b2
-
   def finish(r: String): String = r
+  override def bufferEncoder: Encoder[String] = Encoders.STRING
+  override def outputEncoder: Encoder[String] = Encoders.STRING
+}
+
+
+class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
+  extends Aggregator[IN, OUT, OUT] {
+
+  private val numeric = implicitly[Numeric[OUT]]
+  override def zero: OUT = numeric.zero
+  override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a))
+  override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2)
+  override def finish(reduction: OUT): OUT = reduction
+  override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
+  override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]]
 }
 
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
   }
 
+  test("generic typed sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn),
+      ("a", 4.0), ("b", 3.0))
+
+    checkDataset(
+      ds.groupByKey(_._1)
+        .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn),
+      ("a", 4), ("b", 3))
+  }
+
   test("SPARK-12555 - result should not be corrupted after input columns are reordered") {
     val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
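
Usage note (editorial, not part of the patch): after this change an Aggregator carries
its own encoders, so `toColumn` no longer takes implicit encoder parameters. Below is a
minimal sketch of the new API, modeled on the REPL test above. The `SimpleSum` and
`SumExample` names and the local SparkSession setup are illustrative assumptions, not
code from this diff.

    import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.expressions.Aggregator

    // bufferEncoder and outputEncoder are abstract members of Aggregator,
    // so every concrete Aggregator must now define them.
    object SimpleSum extends Aggregator[Int, Int, Int] {
      def zero: Int = 0                                    // initial value
      def reduce(b: Int, a: Int): Int = b + a              // add an element to the running total
      def merge(b1: Int, b2: Int): Int = b1 + b2           // merge intermediate values
      def finish(b: Int): Int = b                          // return the final result
      def bufferEncoder: Encoder[Int] = Encoders.scalaInt  // encoder for the Int buffer
      def outputEncoder: Encoder[Int] = Encoders.scalaInt  // encoder for the Int output
    }

    object SumExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").appName("SumExample").getOrCreate()
        import spark.implicits._

        val ds = Seq(1, 2, 3, 4).toDS()
        // No call-site encoders: toColumn resolves them from the Aggregator itself.
        ds.select(SimpleSum.toColumn).show()

        spark.stop()
      }
    }

Moving the encoders onto the class is also what simplifies the Java side: a Java
subclass can override bufferEncoder()/outputEncoder() directly (as IntSumOf does in
JavaDatasetAggregatorSuite above), whereas the old design forced Java callers to pass
encoders explicitly, e.g. toColumn(Encoders.INT(), Encoders.INT()) in the removed
test lines.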