Commit b2d195e1 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-11554][SQL] add map/flatMap to GroupedDataset

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9521 from cloud-fan/map.
parent 26739059
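
For reference, a minimal usage sketch of the two methods this commit adds, mirroring the test suites in the diff below. The `toDS()` setup assumes a `SQLContext` with `sqlContext.implicits._` in scope; names are those introduced by this patch.

    import sqlContext.implicits._

    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
    val grouped = ds.groupBy(_.length)

    // map: exactly one output element per group.
    val sizes = grouped.map { case (key, iter) => (key, iter.size) }

    // flatMap: zero or more output elements per group (any TraversableOnce).
    val expanded = grouped.flatMap { case (key, iter) => Iterator(key.toString, iter.mkString) }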
@@ -494,7 +494,7 @@ case class AppendColumn[T, U](
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
-      func: (K, Iterator[T]) => Iterator[U],
+      func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
@@ -514,7 +514,7 @@ object MapGroups {
  * object representation of all the rows with that key.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
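Widening the function type from `Iterator[U]` to `TraversableOnce[U]` lets user functions return any one-pass collection, not only an `Iterator`. A small illustration of the return types this admits (a sketch, not code from the patch):

    // All of these satisfy (K, Iterator[T]) => TraversableOnce[U];
    // under the old Iterator[U] signature only the last one compiled.
    val asSeq  = (k: Int, it: Iterator[Int]) => it.filter(_ % 2 == 0).toSeq
    val asList = (k: Int, it: Iterator[Int]) => List(k)
    val asIter = (k: Int, it: Iterator[Int]) => it.take(1)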
@@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql](
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
    */
-  def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
-  def mapGroups[U](
+  def flatMap[U](
       f: JFunction2[K, JIterator[T], JIterator[U]],
       encoder: Encoder[U]): Dataset[U] = {
-    mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }
 
+  /**
+   * Applies the given function to each group of data. For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   */
+  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
+    new Dataset[U](
+      sqlContext,
+      MapGroups(func, groupingAttributes, logicalPlan))
+  }
+
+  def map[U](
+      f: JFunction2[K, JIterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    map((key, data) => f.call(key, data.asJava))(encoder)
+  }
+
   // To ensure valid overloading.
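As the implementation above shows, `map` is `flatMap` restricted to exactly one output row per group: the user function's single result is wrapped in a one-element `Iterator` before being handed to `MapGroups`. In terms of the public API, the two calls below should produce the same `Dataset` (an illustrative equivalence, not code from the patch):

    // map(f) ...
    grouped.map { case (k, it) => (k, it.size) }
    // ... behaves like flatMap returning exactly one element:
    grouped.flatMap { case (k, it) => Iterator((k, it.size)) }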
@@ -356,7 +356,7 @@ case class AppendColumns[T, U](
  * being output.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
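On the physical side only the signature changes: Scala's `flatMap` already accepts any function returning a `TraversableOnce`-like collection, so the operator can flatten the per-group results without extra conversion. A plain-Scala illustration of that flattening (standalone sketch, not Spark internals):

    // flatMap over a function returning a collection flattens all
    // per-group results into a single sequence.
    val groups = Seq(0 -> Seq(2, 4), 1 -> Seq(1, 3, 5))
    val flattened = groups.flatMap { case (k, vs) => Seq(k, vs.sum) }
    // flattened == Seq(0, 6, 1, 9)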
@@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());
@@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> ds = context.createDataset(data, e.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());
@@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       0, 1)
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.mapGroups { case (g, iter) =>
+    val agged = grouped.map { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
-      Iterator((name, iter.size))
+      (name, iter.size)
     }
 
     checkAnswer(
       agged,
       ("even", 5), ("odd", 6))
   }
+
+  test("groupBy function, flatMap") {
+    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
+    val grouped = ds.groupBy(_.length)
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+
+    checkAnswer(
+      agged,
+      "1", "abc", "3", "xyz", "5", "hello")
+  }
 }
@@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1))
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g._1, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns, mapGroups") {
+  test("groupBy function, flatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy(v => (v._1, "word"))
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+
+    checkAnswer(
+      agged,
+      "a", "30", "b", "3", "c", "1")
+  }
+
+  test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g.getString(0), iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey, mapGroups") {
+  test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").asKey[String]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey tuple, mapGroups") {
+  test("groupBy columns asKey tuple, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
   }
 
-  test("groupBy columns asKey class, mapGroups") {
+  test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
......