diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d17943db7ca92c83ced8afef957cf6d606..eaeb010b0e4fa38c0d2f0a6bba8592027c8db2d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
  * the given percentage(s) with value range in [0.0, 1.0].
  *
- * The operator is bound to the slower sort based aggregation path because the number of elements
- * and their partial order cannot be determined in advance. Therefore we have to store all the
- * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined in advance, we
+ * have to store all the elements in memory, and too many elements can cause GC pauses and
+ * eventually an OutOfMemoryError.
  *
  * @param child child expression that produce numeric column value with `child.eval(inputRow)`
  * @param percentageExpression Expression that represents a single percentage value or an array of
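
The Percentile change above is comment wording only; behavior is unchanged. For context, a usage
sketch (hypothetical, assuming a SparkSession named `spark`): the exact percentile implementation
buffers every value of a group on the heap, so memory use grows with group size.

```scala
// Hypothetical illustration, not part of this patch: exact percentile keeps all
// values of a group in memory before sorting, so large groups mean heap pressure.
val df = spark.range(1000000L).selectExpr("id % 4 AS g", "id AS v")
df.createOrReplaceTempView("t")
spark.sql("SELECT g, percentile(v, array(0.25, 0.5, 0.75)) FROM t GROUP BY g").show()
```
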
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a128f43923e580e7b60de35f920fcf4fcf..411f058510ca7226202458593f775ea9671413be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -27,14 +27,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
- * The Collect aggregate function collects all seen expression values into a list of values.
+ * A base class for collect_list and collect_set aggregate functions.
  *
- * The operator is bound to the slower sort based aggregation path because the number of
- * elements (and their memory usage) can not be determined in advance. This also means that the
- * collected elements are stored on heap, and that too many elements can cause GC pauses and
- * eventually Out of Memory Errors.
+ * All of the collected elements are stored in memory, so too many elements can cause GC
+ * pauses and eventually an OutOfMemoryError.
  */
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {
 
   val child: Expression
 
@@ -44,40 +42,44 @@
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def supportsPartial: Boolean = false
-
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
   // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
   // actual order of input rows.
   override def deterministic: Boolean = false
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
-
-  override def initialize(b: InternalRow): Unit = {
-    buffer.clear()
-  }
+  override def update(buffer: T, input: InternalRow): T = {
+    val value = child.eval(input)
 
-  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
-    val value = child.eval(input)
     if (value != null) {
       buffer += value
     }
+    buffer
   }
 
-  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
-    sys.error("Collect cannot be used in partial aggregations.")
+  override def merge(buffer: T, other: T): T = {
+    buffer ++= other
   }
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(buffer: T): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  private lazy val projection = UnsafeProjection.create(
+    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+  private lazy val row = new UnsafeRow(1)
+
+  override def serialize(obj: T): Array[Byte] = {
+    val array = new GenericArrayData(obj.toArray)
+    projection.apply(InternalRow.apply(array)).getBytes()
+  }
+
+  override def deserialize(bytes: Array[Byte]): T = {
+    val buffer = createAggregationBuffer()
+    row.pointTo(bytes, bytes.length)
+    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    buffer
+  }
 }
 
 /**
@@ -88,7 +90,7 @@ abstract class Collect extends ImperativeAggregate {
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -98,9 +100,9 @@ case class CollectList(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def prettyName: String = "collect_list"
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+  override def prettyName: String = "collect_list"
 }
 
 /**
@@ -111,7 +113,7 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -131,5 +133,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
 }
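
To make the new serialize/deserialize contract above concrete, here is a standalone round-trip
sketch (a sketch only, assuming a LongType element type; it mirrors the projection-based encoding
in the hunk above and is not part of the patch):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, DataType, LongType}

// Projection with a single array-typed field, matching the buffer layout above.
val projection = UnsafeProjection.create(
  Array[DataType](ArrayType(LongType, containsNull = false)))

// serialize: wrap the buffer contents in a GenericArrayData and take the bytes.
val buffer = Seq[Any](1L, 2L, 3L)
val bytes = projection(InternalRow(new GenericArrayData(buffer.toArray))).getBytes

// deserialize: point an empty UnsafeRow at the bytes and read the array back.
val row = new UnsafeRow(1)
row.pointTo(bytes, bytes.length)
assert(row.getArray(0).toLongArray.toSeq == Seq(1L, 2L, 3L))
```
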
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba14ce541923b5efd46fb79406b81633bca..ccd4ae6c2d845c6899d9524a99682290ed0b554e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
  * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
  * buffer's storage format, which is not supported by hash based aggregation. Hash based
  * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
  */
 abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 
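
To illustrate the contract the added NOTE refers to, a hypothetical minimal
TypedImperativeAggregate follows (TypedSum and its byte encoding are invented for illustration and
are not part of this patch): the buffer is an arbitrary JVM object, and it crosses the
partial/final boundary serialized as BinaryType bytes, which plain hash aggregation cannot mutate
in place but ObjectHashAggregateExec can handle.

```scala
import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.types.{DataType, LongType}

// Hypothetical: a sum over a LongType child, kept in a boxed Long buffer.
case class TypedSum(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[java.lang.Long] {

  override def createAggregationBuffer(): java.lang.Long = 0L

  override def update(buffer: java.lang.Long, input: InternalRow): java.lang.Long = {
    val v = child.eval(input)
    if (v != null) buffer + v.asInstanceOf[Long] else buffer
  }

  override def merge(buffer: java.lang.Long, other: java.lang.Long): java.lang.Long =
    buffer + other

  override def eval(buffer: java.lang.Long): Any = buffer.longValue()

  // The buffer travels between partial and final aggregation as BinaryType bytes.
  override def serialize(buffer: java.lang.Long): Array[Byte] =
    ByteBuffer.allocate(java.lang.Long.BYTES).putLong(buffer).array()

  override def deserialize(bytes: Array[Byte]): java.lang.Long =
    ByteBuffer.wrap(bytes).getLong

  override def children: Seq[Expression] = child :: Nil
  override def nullable: Boolean = false
  override def dataType: DataType = LongType
  override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newOffset)
  override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newOffset)
}
```
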
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3b659cf3d74b02253a276673820ada2e18..5c1faaecdb548e627e4d3ba8ff9dd71d3c0982fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     comparePlans(input, rewrite)
   }
 
-  test("single distinct group with non-partial aggregates") {
-    val input = testRelation
-      .groupBy('a, 'd)(
-        countDistinct('e, 'c).as('agg1),
-        CollectSet('b).toAggregateExpression().as('agg2))
-      .analyze
-    checkRewrite(RewriteDistinctAggregates(input))
-  }
-
   test("multiple distinct groups") {
     val input = testRelation
       .groupBy('a)(countDistinct('b, 'c), countDistinct('d))