Commit 8d575246 authored by Reynold Xin

[SPARK-11933][SQL] Rename mapGroup -> mapGroups and flatMapGroup -> flatMapGroups.

Based on feedback from Matei, this is more consistent with mapPartitions in Spark.

Also addresses some of the cleanups from a previous commit by renaming the type variables (e.g. tEncoder -> vEncoder).
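
For reference, the renamed methods are called like this from Scala (an illustrative sketch, not part of this patch; it assumes a SQLContext named sqlContext with its implicits imported, mirroring the test suites below):

    import org.apache.spark.sql.Dataset
    import sqlContext.implicits._

    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
    val grouped = ds.groupBy(_.length)  // GroupedDataset[Int, String]

    // mapGroups emits exactly one output record per grouping key ...
    val counts: Dataset[(Int, Int)] =
      grouped.mapGroups { (len, strs) => (len, strs.size) }

    // ... while flatMapGroups may emit zero or more records per key.
    val tokens: Dataset[String] =
      grouped.flatMapGroups { (len, strs) => Iterator(len.toString, strs.mkString) }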

Author: Reynold Xin <rxin@databricks.com>

Closes #9919 from rxin/SPARK-11933.
parent 026ea2ea
@@ -23,6 +23,6 @@ import java.util.Iterator;
 /**
  * A function that returns zero or more output records from each grouping key and its values.
  */
-public interface FlatMapGroupFunction<K, V, R> extends Serializable {
+public interface FlatMapGroupsFunction<K, V, R> extends Serializable {
   Iterable<R> call(K key, Iterator<V> values) throws Exception;
 }
@@ -23,6 +23,6 @@ import java.util.Iterator;
 /**
  * Base interface for a map function used in GroupedDataset's mapGroup function.
  */
-public interface MapGroupFunction<K, V, R> extends Serializable {
+public interface MapGroupsFunction<K, V, R> extends Serializable {
   R call(K key, Iterator<V> values) throws Exception;
 }
@@ -43,7 +43,7 @@ import org.apache.spark.sql.expressions.Aggregator
 @Experimental
 class GroupedDataset[K, V] private[sql](
     kEncoder: Encoder[K],
-    tEncoder: Encoder[V],
+    vEncoder: Encoder[V],
     val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
@@ -53,12 +53,12 @@ class GroupedDataset[K, V] private[sql](
   // queryexecution.
   private implicit val unresolvedKEncoder = encoderFor(kEncoder)
-  private implicit val unresolvedTEncoder = encoderFor(tEncoder)
+  private implicit val unresolvedVEncoder = encoderFor(vEncoder)

   private val resolvedKEncoder =
     unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
-  private val resolvedTEncoder =
-    unresolvedTEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
+  private val resolvedVEncoder =
+    unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)

   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
@@ -76,7 +76,7 @@ class GroupedDataset[K, V] private[sql](
   def keyAs[L : Encoder]: GroupedDataset[L, V] =
     new GroupedDataset(
       encoderFor[L],
-      unresolvedTEncoder,
+      unresolvedVEncoder,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -110,13 +110,13 @@ class GroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
+  def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(
         f,
         resolvedKEncoder,
-        resolvedTEncoder,
+        resolvedVEncoder,
         groupingAttributes,
         logicalPlan))
   }
@@ -138,8 +138,8 @@ class GroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }

   /**
@@ -158,9 +158,9 @@ class GroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+  def mapGroups[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
     val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
-    flatMapGroup(func)
+    flatMapGroups(func)
   }

   /**
@@ -179,8 +179,8 @@ class GroupedDataset[K, V] private[sql](
   *
   * @since 1.6.0
   */
-  def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
-    mapGroup((key, data) => f.call(key, data.asJava))(encoder)
+  def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroups((key, data) => f.call(key, data.asJava))(encoder)
   }

   /**
@@ -192,8 +192,8 @@ class GroupedDataset[K, V] private[sql](
   def reduce(f: (V, V) => V): Dataset[(K, V)] = {
     val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
-    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
-    flatMapGroup(func)
+    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
+    flatMapGroups(func)
   }

   /**
@@ -213,7 +213,7 @@ class GroupedDataset[K, V] private[sql](
   private def withEncoder(c: Column): Column = c match {
     case tc: TypedColumn[_, _] =>
-      tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes)
+      tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes)
     case _ => c
   }
@@ -227,7 +227,7 @@ class GroupedDataset[K, V] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(
-        _.withInputType(resolvedTEncoder, dataAttributes).named)
+        _.withInputType(resolvedVEncoder, dataAttributes).named)
     val keyColumn = if (groupingAttributes.length > 1) {
       Alias(CreateStruct(groupingAttributes), "key")()
     } else {
@@ -304,7 +304,7 @@ class GroupedDataset[K, V] private[sql](
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
+    implicit def uEnc: Encoder[U] = other.unresolvedVEncoder
     new Dataset[R](
       sqlContext,
       CoGroup(
......
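
For context, the cogroup method touched above is typically used like this (again an illustrative sketch, not part of the patch; the clicks/purchases datasets are hypothetical and a SQLContext named sqlContext with its implicits imported is assumed):

    import org.apache.spark.sql.Dataset
    import sqlContext.implicits._

    val clicks    = Seq(("u1", "home"), ("u1", "cart"), ("u2", "home")).toDS().groupBy(_._1)
    val purchases = Seq(("u1", 42.0), ("u2", 7.5)).toDS().groupBy(_._1)

    // For each key, cogroup receives the key and one value iterator per side,
    // and may emit zero or more output records.
    val summary: Dataset[(String, Int, Double)] =
      clicks.cogroup(purchases) { (user, cs, ps) =>
        Iterator((user, cs.size, ps.map(_._2).sum))
      }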
@@ -170,7 +170,7 @@ public class JavaDatasetSuite implements Serializable {
         }
       }, Encoders.INT());

-    Dataset<String> mapped = grouped.mapGroup(new MapGroupFunction<Integer, String, String>() {
+    Dataset<String> mapped = grouped.mapGroups(new MapGroupsFunction<Integer, String, String>() {
       @Override
       public String call(Integer key, Iterator<String> values) throws Exception {
         StringBuilder sb = new StringBuilder(key.toString());
@@ -183,8 +183,8 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());

-    Dataset<String> flatMapped = grouped.flatMapGroup(
-      new FlatMapGroupFunction<Integer, String, String>() {
+    Dataset<String> flatMapped = grouped.flatMapGroups(
+      new FlatMapGroupsFunction<Integer, String, String>() {
         @Override
         public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
@@ -249,8 +249,8 @@ public class JavaDatasetSuite implements Serializable {
     GroupedDataset<Integer, String> grouped =
       ds.groupBy(length(col("value"))).keyAs(Encoders.INT());

-    Dataset<String> mapped = grouped.mapGroup(
-      new MapGroupFunction<Integer, String, String>() {
+    Dataset<String> mapped = grouped.mapGroups(
+      new MapGroupsFunction<Integer, String, String>() {
         @Override
         public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
......
@@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.mapGroup { case (g, iter) =>
+    val agged = grouped.mapGroups { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
       (name, iter.size)
     }
@@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, flatMap") {
     val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
     val grouped = ds.groupBy(_.length)
-    val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) }
+    val agged = grouped.flatMapGroups { case (g, iter) => Iterator(g.toString, iter.mkString) }

     checkAnswer(
       agged,
......
@@ -224,7 +224,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.mapGroup { case (g, iter) => (g._1, iter.map(_._2).sum) }
+    val agged = grouped.mapGroups { case (g, iter) => (g._1, iter.map(_._2).sum) }

     checkAnswer(
       agged,
@@ -234,7 +234,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, flatMap") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.flatMapGroup { case (g, iter) =>
+    val agged = grouped.flatMapGroups { case (g, iter) =>
       Iterator(g._1, iter.map(_._2).sum.toString)
     }
@@ -255,7 +255,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.mapGroup { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
+    val agged = grouped.mapGroups { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }

     checkAnswer(
       agged,
@@ -265,7 +265,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").keyAs[String]
-    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
+    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

     checkAnswer(
       agged,
@@ -275,7 +275,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1", lit(1)).keyAs[(String, Int)]
-    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
+    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

     checkAnswer(
       agged,
@@ -285,7 +285,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).keyAs[ClassData]
-    val agged = grouped.mapGroup { case (g, iter) => (g, iter.map(_._2).sum) }
+    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }

     checkAnswer(
       agged,
......