diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09aac00a455f97e012d2dc0dd9e8ad8d54d7da75..e151ac04ede2ad0ecc3ca06a3e8d8968bddcf542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -494,7 +494,7 @@ case class AppendColumn[T, U](
 /** Factory for constructing new `MapGroups` nodes. */
 object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
-      func: (K, Iterator[T]) => Iterator[U],
+      func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups[K, T, U] = {
     new MapGroups(
@@ -514,7 +514,7 @@ object MapGroups {
  * object representation of all the rows with that key.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
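
The Catalyst node now accepts any `TraversableOnce[U]` rather than requiring an `Iterator[U]`, so user functions no longer need to wrap their results in an iterator. A minimal sketch of the kind of function this widening admits (illustrative only, not part of the patch):

// Illustrative only: with TraversableOnce, a grouping function can return a Seq
// (or any other single-pass collection) directly instead of building an Iterator.
val topTwoPerGroup: (Int, Iterator[String]) => TraversableOnce[String] =
  (key, values) => values.toSeq.sorted.take(2)
// Under the old signature this would have needed a trailing .iterator.
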
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b2803d5a9a1e3caf3a84cb48a48dc412d31d32fc..5c3f6265458757be1bbc099416127a413701c4d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql](
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
    */
-  def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = {
+  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
-  def mapGroups[U](
+  def flatMap[U](
       f: JFunction2[K, JIterator[T], JIterator[U]],
       encoder: Encoder[U]): Dataset[U] = {
-    mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder)
+    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  }
+
+  /**
+   * Applies the given function to each group of data.  For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type; the results form a new [[Dataset]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory.  However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   */
+  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
+    new Dataset[U](
+      sqlContext,
+      MapGroups(func, groupingAttributes, logicalPlan))
+  }
+
+  def map[U](
+      f: JFunction2[K, JIterator[T], U],
+      encoder: Encoder[U]): Dataset[U] = {
+    map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
   // To ensure valid overloading.
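
For reference, a usage sketch of the renamed API (mirrors the updated test suites below; assumes the usual `sqlContext.implicits._` import and is not part of the patch):

// map produces exactly one output element per group; flatMap may produce zero or more.
val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
val grouped = ds.groupBy(_._1)

val sums = grouped.map { case (key, values) => (key, values.map(_._2).sum) }
// -> ("a", 30), ("b", 1)

val keysAndCounts = grouped.flatMap { case (key, values) => Seq(key, values.size.toString) }
// -> "a", "2", "b", "1"
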
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 799650a4f784f488c9fd9b654e3ce659f223b71e..2593b16b1c8d7bbb6b8ba51d9ac0b3c382f63073 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -356,7 +356,7 @@ case class AppendColumns[T, U](
  * being output.
  */
 case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => Iterator[U],
+    func: (K, Iterator[T]) => TraversableOnce[U],
     kEncoder: ExpressionEncoder[K],
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a9493d576d179d6040a401c629596963da7eca68..0d3b1a5af52c494bb93638f777d45d330c8f386b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -170,15 +170,15 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());
@@ -224,15 +224,15 @@ public class JavaDatasetSuite implements Serializable {
     Dataset<String> ds = context.createDataset(data, e.STRING());
     GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
 
-    Dataset<String> mapped = grouped.mapGroups(
-      new Function2<Integer, Iterator<String>, Iterator<String>>() {
+    Dataset<String> mapped = grouped.map(
+      new Function2<Integer, Iterator<String>, String>() {
         @Override
-        public Iterator<String> call(Integer key, Iterator<String> data) throws Exception {
+        public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
           while (data.hasNext()) {
             sb.append(data.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return sb.toString();
         }
       },
       e.STRING());
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index e3b0346f857d33d02cdfc07a150e207d44de6cea..fcf03f7180984a77e61219fa94bfd48c2c0dac0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       0, 1)
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.mapGroups { case (g, iter) =>
+    val agged = grouped.map { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
-      Iterator((name, iter.size))
+      (name, iter.size)
     }
 
     checkAnswer(
       agged,
       ("even", 5), ("odd", 6))
   }
+
+  test("groupBy function, flatMap") {
+    val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
+    val grouped = ds.groupBy(_.length)
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+
+    checkAnswer(
+      agged,
+      "1", "abc", "3", "xyz", "5", "hello")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d61e17edc64edaf6c65b5a03485343633f1fbe93..6f1174e6577e3a65bd0f6d3e8265ef57e5ee4804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (1, 1))
   }
 
-  test("groupBy function, mapGroups") {
+  test("groupBy function, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy(v => (v._1, "word"))
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g._1, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns, mapGroups") {
+  test("groupBy function, fatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy(v => (v._1, "word"))
+    val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
+
+    checkAnswer(
+      agged,
+      "a", "30", "b", "3", "c", "1")
+  }
+
+  test("groupBy columns, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1")
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g.getString(0), iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey, mapGroups") {
+  test("groupBy columns asKey, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1").asKey[String]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
 
-  test("groupBy columns asKey tuple, mapGroups") {
+  test("groupBy columns asKey tuple, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,
       (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
   }
 
-  test("groupBy columns asKey class, mapGroups") {
+  test("groupBy columns asKey class, map") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
     val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
-    val agged = grouped.mapGroups { case (g, iter) =>
-      Iterator((g, iter.map(_._2).sum))
-    }
+    val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) }
 
     checkAnswer(
       agged,