diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index fefe5a3953a6e77f9e551b1c35d3d5ff6994763c..0ab4c9016623e10ebf6bd0dee0a6fb12e63ea538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -230,6 +230,19 @@ object AppendColumns {
       encoderFor[U].namedExpressions,
       child)
   }
+
+  def apply[T : Encoder, U : Encoder](
+      func: T => U,
+      inputAttributes: Seq[Attribute],
+      child: LogicalPlan): AppendColumns = {
+    new AppendColumns(
+      func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes),
+      encoderFor[U].namedExpressions,
+      child)
+  }
 }
 
 /**
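
Why the second `apply`: the existing overload resolves the `T` deserializer against `child.output`, but in `mapValues` (added below in `KeyValueGroupedDataset`) the child plan already carries the grouping columns appended by `groupByKey`, so the deserializer has to be bound to the value attributes alone; hence the explicit `inputAttributes` parameter wrapped in `UnresolvedDeserializer`. A minimal sketch of the resulting plan, assuming a local Spark build that includes this patch (the session setup and object name are illustrative, not part of the change):

```scala
import org.apache.spark.sql.SparkSession

object AppendColumnsPlanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 10), ("b", 20)).toDS()
    // mapValues goes through the new AppendColumns.apply overload internally.
    val mapped = ds.groupByKey(_._1).mapValues(_._2)

    // The extended plan of a downstream Dataset surfaces the
    // AppendColumns -> Project chain that mapValues builds.
    mapped.mapGroups((k, it) => (k, it.sum)).explain(extended = true)

    spark.stop()
  }
}
```
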
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 828eb94efe598a1f22cb3e6e9805c8cc16a21985..4cb0313aa90374d08cb1e615ba5fbec5df159f03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -66,6 +66,48 @@ class KeyValueGroupedDataset[K, V] private[sql](
       dataAttributes,
       groupingAttributes)
 
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key remains unchanged.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
+    val withNewData = AppendColumns(func, dataAttributes, logicalPlan)
+    val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData)
+    val executed = sparkSession.sessionState.executePlan(projected)
+
+    new KeyValueGroupedDataset(
+      encoderFor[K],
+      encoderFor[W],
+      executed,
+      withNewData.newColumns,
+      groupingAttributes)
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key remains unchanged.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1(), Encoders.STRING()).mapValues(t -> t._2(), Encoders.INT()); // Java 8
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    implicit val uEnc: Encoder[W] = encoder
+    mapValues { (v: V) => func.call(v) }
+  }
+
   /**
    * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
    * over the Dataset to extract the keys and then running a distinct operation on those.
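
For reference, a before/after usage sketch of the new public API (again assuming a build that contains this patch; the data and object name are illustrative). Without `mapValues`, every group function has to re-extract the value from the full row; with it, the extraction happens once and the grouping key is left untouched:

```scala
import org.apache.spark.sql.SparkSession

object MapValuesUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("mapValues").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

    // Before this patch: extract the Int inside every group function.
    val before = ds.groupByKey(_._1).mapGroups((k, it) => (k, it.map(_._2).sum))

    // With mapValues: map the values once; the key is unchanged.
    val after = ds.groupByKey(_._1).mapValues(_._2).mapGroups((k, it) => (k, it.sum))

    before.show()
    after.show() // both contain ("a", 30), ("b", 3), ("c", 1), in some order

    spark.stop()
  }
}
```
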
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 5fce9b4fe97ea668dcbec0847d6e632e41de1651..cc367acae2ba4359a8af7690186979c8841d7f36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, mapValues, flatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val keyValue = ds.groupByKey(_._1).mapValues(_._2)
+    val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) }
+    checkDataset(agged, ("a", 30), ("b", 3), ("c", 1))
+
+    val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value"))
+    val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) }
+    checkDataset(agged1, ("a", 30), ("b", 3), ("c", 1))
+  }
+
   test("groupBy function, reduce") {
     val ds = Seq("abc", "xyz", "hello").toDS()
     val agged = ds.groupByKey(_.length).reduceGroups(_ + _)