From 94a9d11ed1f61205af8067bf17d14dc93935ddf8 Mon Sep 17 00:00:00 2001
From: Sean Zhong <seanzhong@databricks.com>
Date: Mon, 8 Aug 2016 22:20:54 +0800
Subject: [PATCH] [SPARK-16906][SQL] Adds auxiliary info like input class and
 input schema in TypedAggregateExpression

## What changes were proposed in this pull request?

This PR changes `TypedColumn.withInputType` to accept the input's full `ExpressionEncoder` instead of only its deserializer expression, so that auxiliary info about the input, namely the input class and the input schema, can be recorded in `TypedAggregateExpression` alongside the existing `inputDeserializer`.
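
For context, a minimal typed aggregation that exercises this code path (the `SumLong` aggregator below is illustrative only, not part of the patch): `Dataset.select` passes `exprEnc` into `TypedColumn.withInputType`, which after this change also records the input's runtime class and schema on the `TypedAggregateExpression`.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Illustrative aggregator (not part of this patch): sums Long values.
object SumLong extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Long): Long = b + a
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(reduction: Long): Long = reduction
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val ds = spark.range(10).as[Long]
// select() calls withInputType(exprEnc, logicalPlan.output); with this
// patch the resulting TypedAggregateExpression also carries the input's
// runtime class and schema (the new inputClass / inputSchema fields).
ds.select(SumLong.toColumn).show()
```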

## How was this patch tested?

Manual test.
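
One way such a manual test might look (an assumed verification step, not spelled out in the patch): run a typed aggregation like the one above and inspect the query plan, where the `TypedAggregateExpression` should now be constructed with the extra fields populated.

```scala
// Assumed manual check: the analyzed plan's TypedAggregateExpression
// should now carry inputClass and inputSchema in addition to the
// existing inputDeserializer.
ds.select(SumLong.toColumn).explain(true)
```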

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14501 from clockfly/typed_aggregation.
---
 .../src/main/scala/org/apache/spark/sql/Column.scala     | 9 ++++++---
 .../src/main/scala/org/apache/spark/sql/Dataset.scala    | 4 ++--
 .../org/apache/spark/sql/KeyValueGroupedDataset.scala    | 2 +-
 .../org/apache/spark/sql/RelationalGroupedDataset.scala  | 2 +-
 .../execution/aggregate/TypedAggregateExpression.scala   | 4 ++++
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a46d1949e9..844ca7a8e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputDeserializer: Expression,
+      inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+        ta.copy(
+          inputDeserializer = Some(unresolvedDeserializer),
+          inputClass = Some(inputEncoder.clsTag.runtimeClass),
+          inputSchema = Some(inputEncoder.schema))
     }
     new TypedColumn[T, U](newExpr, encoder)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9eef5cc5fe..c119df83b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
   @Experimental
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
       logicalPlan)
 
     if (encoder.flat) {
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a6867a67ee..65a725f3d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+      columns.map(_.withInputType(vExprEnc, dataAttributes).named)
     val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 1aa5767038..7cfd1cdc7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 2cdf4703a5..6f7f2f842c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
+      None,
+      None,
       bufferSerializer,
       bufferDeserializer,
       outputEncoder.serializer,
@@ -62,6 +64,8 @@ object TypedAggregateExpression {
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
-- 
GitLab