diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index fefe5a3953a6e77f9e551b1c35d3d5ff6994763c..0ab4c9016623e10ebf6bd0dee0a6fb12e63ea538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -230,6 +230,19 @@ object AppendColumns { encoderFor[U].namedExpressions, child) } + + def apply[T : Encoder, U : Encoder]( + func: T => U, + inputAttributes: Seq[Attribute], + child: LogicalPlan): AppendColumns = { + new AppendColumns( + func.asInstanceOf[Any => Any], + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, + UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes), + encoderFor[U].namedExpressions, + child) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 828eb94efe598a1f22cb3e6e9805c8cc16a21985..4cb0313aa90374d08cb1e615ba5fbec5df159f03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -66,6 +66,48 @@ class KeyValueGroupedDataset[K, V] private[sql]( dataAttributes, groupingAttributes) + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. + * + * {{{ + * // Create values grouped by key from a Dataset[(K, V)] + * ds.groupByKey(_._1).mapValues(_._2) // Scala + * }}} + * + * @since 2.1.0 + */ + def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = { + val withNewData = AppendColumns(func, dataAttributes, logicalPlan) + val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData) + val executed = sparkSession.sessionState.executePlan(projected) + + new KeyValueGroupedDataset( + encoderFor[K], + encoderFor[W], + executed, + withNewData.newColumns, + groupingAttributes) + } + + /** + * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied + * to the data. The grouping key is unchanged by this. + * + * {{{ + * // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>> + * Dataset<Tuple2<String, Integer>> ds = ...; + * KeyValueGroupedDataset<String, Integer> grouped = + * ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8 + * }}} + * + * @since 2.1.0 + */ + def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { + implicit val uEnc = encoder + mapValues { (v: V) => func.call(v) } + } + /** * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping * over the Dataset to extract the keys and then running a distinct operation on those. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5fce9b4fe97ea668dcbec0847d6e632e41de1651..cc367acae2ba4359a8af7690186979c8841d7f36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "a", "30", "b", "3", "c", "1") } + test("groupBy function, mapValues, flatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val keyValue = ds.groupByKey(_._1).mapValues(_._2) + val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) } + checkDataset(agged, ("a", 30), ("b", 3), ("c", 1)) + + val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value")) + val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) } + checkDataset(agged, ("a", 30), ("b", 3), ("c", 1)) + } + test("groupBy function, reduce") { val ds = Seq("abc", "xyz", "hello").toDS() val agged = ds.groupByKey(_.length).reduceGroups(_ + _)