diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09aac00a455f97e012d2dc0dd9e8ad8d54d7da75..e151ac04ede2ad0ecc3ca06a3e8d8968bddcf542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -494,7 +494,7 @@ case class AppendColumn[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( @@ -514,7 +514,7 @@ object MapGroups { * object representation of all the rows with that key. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b2803d5a9a1e3caf3a84cb48a48dc412d31d32fc..5c3f6265458757be1bbc099416127a413701c4d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. */ - def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups(f, groupingAttributes, logicalPlan)) } - def mapGroups[U]( + def flatMap[U]( f: JFunction2[K, JIterator[T], JIterator[U]], encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) + new Dataset[U]( + sqlContext, + MapGroups(func, groupingAttributes, logicalPlan)) + } + + def map[U]( + f: JFunction2[K, JIterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + map((key, data) => f.call(key, data.asJava))(encoder) } // To ensure valid overloading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 799650a4f784f488c9fd9b654e3ce659f223b71e..2593b16b1c8d7bbb6b8ba51d9ac0b3c382f63073 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -356,7 +356,7 @@ case class AppendColumns[T, U]( * being output. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a9493d576d179d6040a401c629596963da7eca68..0d3b1a5af52c494bb93638f777d45d330c8f386b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable { } }, e.INT()); - Dataset<String> mapped = grouped.mapGroups( - new Function2<Integer, Iterator<String>, Iterator<String>>() { + Dataset<String> mapped = grouped.map( + new Function2<Integer, Iterator<String>, String>() { @Override - public Iterator<String> call(Integer key, Iterator<String> data) throws Exception { + public String call(Integer key, Iterator<String> data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); @@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> ds = context.createDataset(data, e.STRING()); GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); - Dataset<String> mapped = grouped.mapGroups( - new Function2<Integer, Iterator<String>, Iterator<String>>() { + Dataset<String> mapped = grouped.map( + new Function2<Integer, Iterator<String>, String>() { @Override - public Iterator<String> call(Integer key, Iterator<String> data) throws Exception { + public String call(Integer key, Iterator<String> data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index e3b0346f857d33d02cdfc07a150e207d44de6cea..fcf03f7180984a77e61219fa94bfd48c2c0dac0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 0, 1) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroups { case (g, iter) => + val agged = grouped.map { case (g, iter) => val name = if (g == 0) "even" else "odd" - Iterator((name, iter.size)) + (name, iter.size) } checkAnswer( agged, ("even", 5), ("odd", 6)) } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupBy(_.length) + val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkAnswer( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d61e17edc64edaf6c65b5a03485343633f1fbe93..6f1174e6577e3a65bd0f6d3e8265ef57e5ee4804 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1)) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g._1, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns, mapGroups") { + test("groupBy function, fatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + + checkAnswer( + agged, + "a", "30", "b", "3", "c", "1") + } + + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g.getString(0), iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey, mapGroups") { + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey tuple, mapGroups") { + test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) } - test("groupBy columns asKey class, mapGroups") { + test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged,