diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a46d1949e94ae5d203809670d07f801d014f8d56..844ca7a8e99caa20d8d35fb3846476ac295a53e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputDeserializer: Expression,
+      inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+        ta.copy(
+          inputDeserializer = Some(unresolvedDeserializer),
+          inputClass = Some(inputEncoder.clsTag.runtimeClass),
+          inputSchema = Some(inputEncoder.schema))
     }
     new TypedColumn[T, U](newExpr, encoder)
   }
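
The signature change above is the crux of the patch: `withInputType` now receives the full `ExpressionEncoder` instead of only its deserializer expression, so it can additionally record the input's runtime class and schema on any `TypedAggregateExpression` it rewrites. A minimal sketch of the three pieces the method now extracts, using a hypothetical `Sale` case class (reused in the sketches below):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Sale(item: String, amount: Double)

val enc = ExpressionEncoder[Sale]()
enc.deserializer        // Expression rebuilding a Sale from internal rows;
                        // still wrapped in UnresolvedDeserializer, as before
enc.clsTag.runtimeClass // classOf[Sale] -> the new inputClass field
enc.schema              // StructType of (item, amount) -> the new inputSchema field
```
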
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9eef5cc5fe42abba7ff194138d06ddd5c3296164..c119df83b3d71af0c1594c6100c1b67ca03116e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
   @Experimental
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
       logicalPlan)
 
     if (encoder.flat) {
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
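
The two typed `select` paths need no new information of their own; they simply hand over the Dataset's encoder wholesale. A hedged usage sketch, assuming a hypothetical `ds: Dataset[Sale]` and the built-in typed aggregates:

```scala
import org.apache.spark.sql.expressions.scalalang.typed

// One typed column: the select[U1] overload above. Dataset[Double].
val total = ds.select(typed.sum[Sale](_.amount))

// Several typed columns: selectUntyped, which now maps
// withInputType(exprEnc, ...) over each column. Dataset[(Long, Double)].
val stats = ds.select(typed.count[Sale](_.item), typed.avg[Sale](_.amount))
```
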
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a6867a67eeade4b7a5e575d584dbfe67e0aa0307..65a725f3d4a81a659766f7dade99ddda68afcbee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+      columns.map(_.withInputType(vExprEnc, dataAttributes).named)
     val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
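
The grouped-by-key path gets the same mechanical change, except the columns are bound to the value encoder `vExprEnc`. A sketch, again assuming `ds: Dataset[Sale]` (with `spark.implicits._` in scope for the key encoder):

```scala
// Each typed column is rebound to the value encoder, so the aggregator
// consumes Sale objects rather than rows. Dataset[(String, Double)].
val perItem = ds.groupByKey(_.item).agg(typed.sum[Sale](_.amount))
```
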
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 1aa5767038d539e60045ea95b8c8d07f557029cb..7cfd1cdc7d5d178f0b9531b4c56e7035a5a70a90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
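
On the untyped `agg` path the encoder being bound is the DataFrame's row encoder, so a typed column used here has to consume `Row`s. A hedged sketch with a hypothetical Row-based aggregator:

```scala
import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.Aggregator

// Sums field 1 of each Row; the input type lines up with the row encoder
// that agg() now passes to withInputType.
object RowSum extends Aggregator[Row, Double, Double] {
  def zero: Double = 0.0
  def reduce(b: Double, r: Row): Double = b + r.getDouble(1)
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(b: Double): Double = b
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

ds.toDF().groupBy("item").agg(RowSum.toColumn)
```
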
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 2cdf4703a5d7a45b1bef04c5461e125ce2bfb60a..6f7f2f842c4262f8a73481c664586ea60a507e4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
+      None,
+      None,
       bufferSerializer,
       bufferDeserializer,
       outputEncoder.serializer,
@@ -62,6 +64,8 @@ object TypedAggregateExpression {
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
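
The expression itself gains the two new fields. `TypedAggregateExpression.apply` still constructs them as `None`; they are only filled in once `withInputType` runs, presumably so later planning can inspect the aggregator's external input type and flattened schema without re-deriving them from the deserializer. A sketch of an aggregator whose encoder would populate them, reusing the hypothetical `Sale` class from the first sketch:

```scala
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object SumAmount extends Aggregator[Sale, Double, Double] {
  def zero: Double = 0.0
  def reduce(b: Double, s: Sale): Double = b + s.amount
  def merge(b1: Double, b2: Double): Double = b1 + b2
  def finish(b: Double): Double = b
  def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// SumAmount.toColumn wraps a TypedAggregateExpression with inputDeserializer,
// inputClass and inputSchema all None; a later select/agg calls withInputType,
// setting them to Sale's deserializer, classOf[Sale] and the encoder's schema.
val col = SumAmount.toColumn
```
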