diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 0011096d4a47bcc006a2cd0e34a5ecb047c9bcbf..f7c22dddb8cc0dabb99fa260fd31a730a7577321 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,12 +17,13 @@ package org.apache.spark.util.sketch; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; /** - * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for frequency estimation in * sub-linear space. Currently, supported data types include: * <ul> * <li>{@link Byte}</li> @@ -30,10 +31,6 @@ import java.io.OutputStream; * <li>{@link Integer}</li> * <li>{@link Long}</li> * <li>{@link String}</li> - * <li>{@link Float}</li> - * <li>{@link Double}</li> - * <li>{@link java.math.BigDecimal}</li> - * <li>{@link Boolean}</li> * </ul> * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters: * <ol> @@ -177,6 +174,11 @@ public abstract class CountMinSketch { */ public abstract void writeTo(OutputStream out) throws IOException; + /** + * Serializes this {@link CountMinSketch} and returns the serialized form. + */ + public abstract byte[] toByteArray() throws IOException; + /** * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to * close the stream. @@ -185,6 +187,16 @@ public abstract class CountMinSketch { return CountMinSketchImpl.readFrom(in); } + /** + * Reads in a {@link CountMinSketch} from a byte array. + */ + public static CountMinSketch readFrom(byte[] bytes) throws IOException { + InputStream in = new ByteArrayInputStream(bytes); + CountMinSketch cms = readFrom(in); + in.close(); + return cms; + } + /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}.
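The new `toByteArray` / `readFrom(byte[])` pair above removes the stream plumbing that callers previously had to write by hand. A minimal round-trip sketch of the new API, assuming the common/sketch module is on the classpath (the eps/confidence/seed values and items below are purely illustrative):

    import org.apache.spark.util.sketch.CountMinSketch

    // Build a small sketch: eps = 0.001, confidence = 0.99, seed = 42.
    val sketch = CountMinSketch.create(0.001, 0.99, 42)
    (1 to 100).foreach(i => sketch.add(i % 10))

    // New in this patch: serialize directly to a byte array...
    val bytes: Array[Byte] = sketch.toByteArray
    // ...and read it back without constructing streams explicitly.
    val restored = CountMinSketch.readFrom(bytes)

    assert(restored == sketch)
    assert(restored.estimateCount(3) == sketch.estimateCount(3))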
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 94ab3a98cb65af942039fe74d18322d582f71432..045fec33a282a18d401b55b4348d23f5fa6c599f 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,15 +17,7 @@ package org.apache.spark.util.sketch; -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.OutputStream; -import java.io.Serializable; -import java.math.BigDecimal; +import java.io.*; import java.util.Arrays; import java.util.Random; @@ -153,16 +145,8 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); - } else if (item instanceof BigDecimal) { - addString(((BigDecimal) item).toString(), count); } else if (item instanceof byte[]) { addBinary((byte[]) item, count); - } else if (item instanceof Float) { - addLong(Float.floatToIntBits((Float) item), count); - } else if (item instanceof Double) { - addLong(Double.doubleToLongBits((Double) item), count); - } else if (item instanceof Boolean) { - addLong(((Boolean) item) ? 1L : 0L, count); } else { addLong(Utils.integralToLong(item), count); } @@ -227,6 +211,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return ((int) hash) % width; } + private static int[] getHashBuckets(String key, int hashCount, int max) { + return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max); + } + private static int[] getHashBuckets(byte[] b, int hashCount, int max) { int[] result = new int[hashCount]; int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0); @@ -240,18 +228,9 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { @Override public long estimateCount(Object item) { if (item instanceof String) { - return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item)); - } else if (item instanceof BigDecimal) { - return estimateCountForBinaryItem( - Utils.getBytesFromUTF8String(((BigDecimal) item).toString())); + return estimateCountForStringItem((String) item); } else if (item instanceof byte[]) { return estimateCountForBinaryItem((byte[]) item); - } else if (item instanceof Float) { - return estimateCountForLongItem(Float.floatToIntBits((Float) item)); - } else if (item instanceof Double) { - return estimateCountForLongItem(Double.doubleToLongBits((Double) item)); - } else if (item instanceof Boolean) { - return estimateCountForLongItem(((Boolean) item) ? 
1L : 0L); } else { return estimateCountForLongItem(Utils.integralToLong(item)); } @@ -265,6 +244,15 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { return res; } + private long estimateCountForStringItem(String item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + private long estimateCountForBinaryItem(byte[] item) { long res = Long.MAX_VALUE; int[] buckets = getHashBuckets(item, depth, width); @@ -332,6 +320,14 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable { } } + @Override + public byte[] toByteArray() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeTo(out); + out.close(); + return out.toByteArray(); + } + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { CountMinSketchImpl sketch = new CountMinSketchImpl(); sketch.readFrom0(in); diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index 2c358fcee46ce5a7e284683a914b4746a2fb0cdc..174eb01986c4f1e0562c0b7af13b0315128c5178 100644 --- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.util.sketch import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import java.nio.charset.StandardCharsets import scala.reflect.ClassTag import scala.util.Random @@ -26,9 +25,9 @@ import scala.util.Random import org.scalatest.FunSuite // scalastyle:ignore funsuite class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite - private val epsOfTotalCount = 0.0001 + private val epsOfTotalCount = 0.01 - private val confidence = 0.99 + private val confidence = 0.9 private val seed = 42 @@ -45,12 +44,6 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite } def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = { - def getProbeItem(item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } - test(s"accuracy - $typeName") { // Uses fixed seed to ensure reproducible test execution val r = new Random(31) @@ -63,7 +56,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val exactFreq = { val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) + sampledItems.groupBy(identity).mapValues(_.length.toLong) } val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -74,12 +67,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val probCorrect = { val numErrors = allItems.map { item => - val count = exactFreq.getOrElse(getProbeItem(item), 0L) + val count = exactFreq.getOrElse(item, 0L) val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems if (ratio > epsOfTotalCount) 1 else 0 }.sum - 1D - numErrors.toDouble / numAllItems + 1.0 - (numErrors.toDouble / numAllItems) } assert( @@ -96,9 +89,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val numToMerge = 5 val numItemsPerSketch = 100000 - val perSketchItems = Array.fill(numToMerge, 
numItemsPerSketch) { - itemGenerator(r) - } + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { itemGenerator(r) } val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -113,11 +104,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val mergedSketch = sketches.reduce(_ mergeInPlace _) checkSerDe(mergedSketch) - val expectedSketch = { - val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) - perSketchItems.foreach(_.foreach(sketch.add)) - sketch - } + val expectedSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(expectedSketch.add)) perSketchItems.foreach { _.foreach { item => @@ -142,17 +130,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } - testItemType[Float]("Float") { _.nextFloat() } - - testItemType[Double]("Double") { _.nextDouble() } - - testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) } - - testItemType[Boolean]("Boolean") { _.nextBoolean() } - - testItemType[Array[Byte]]("Binary") { r => - Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20))) - } + testItemType[Array[Byte]]("Byte array") { r => r.nextString(r.nextInt(60)).getBytes } test("incompatible merge") { intercept[IncompatibleMergeException] { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4995af034f654e0faa6c9ed4710ecc4ad5b61314..b113bbf803d950cea8327a89dee27fab002b5ffb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,6 +34,11 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { + lazy val v22excludes = v21excludes ++ Seq( + // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray") + ) + // Exclude rules for 2.1.x lazy val v21excludes = v20excludes ++ { Seq( @@ -912,7 +917,8 @@ object MimaExcludes { } def excludes(version: String) = version match { - case v if v.startsWith("2.1") => v21excludes + case v if v.startsWith("2.2") => v22excludes + case v if v.startsWith("2.1") => v22excludes // TODO: Update this when we bump version to 2.2 case v if v.startsWith("2.0") => v20excludes case _ => Seq() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index f5f185f2c54287180adf24e7a1aae3ef103cbc64..612c19831f0b2bb22c71e222f96cce8dfdbd5382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -42,9 +40,9 @@ import org.apache.spark.util.sketch.CountMinSketch @ExpressionDescription( usage = """ _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp, - confidence and seed. 
The result is an array of bytes, which should be deserialized to a - `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join - size estimation. + confidence and seed. The result is an array of bytes, which can be deserialized to a + `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for + frequency estimation in sub-linear space. """) case class CountMinSketchAgg( child: Expression, @@ -75,13 +73,13 @@ case class CountMinSketchAgg( } else if (!epsExpression.foldable || !confidenceExpression.foldable || !seedExpression.foldable) { TypeCheckFailure( - "The eps, confidence or seed provided must be a literal or constant foldable") + "The eps, confidence or seed provided must be a literal or foldable") } else if (epsExpression.eval() == null || confidenceExpression.eval() == null || seedExpression.eval() == null) { TypeCheckFailure("The eps, confidence or seed provided should not be null") - } else if (eps <= 0D) { + } else if (eps <= 0.0) { TypeCheckFailure(s"Relative error must be positive (current value = $eps)") - } else if (confidence <= 0D || confidence >= 1D) { + } else if (confidence <= 0.0 || confidence >= 1.0) { TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)") } else { TypeCheckSuccess @@ -97,9 +95,6 @@ case class CountMinSketchAgg( // Ignore empty rows if (value != null) { child.dataType match { - // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them - // into acceptable types for `CountMinSketch`. - case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal) // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary` // instead of `addString` to avoid unnecessary conversion. 
case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes) @@ -115,14 +110,11 @@ case class CountMinSketchAgg( override def eval(buffer: CountMinSketch): Any = serialize(buffer) override def serialize(buffer: CountMinSketch): Array[Byte] = { - val out = new ByteArrayOutputStream() - buffer.writeTo(out) - out.toByteArray + buffer.toByteArray } override def deserialize(storageFormat: Array[Byte]): CountMinSketch = { - val in = new ByteArrayInputStream(storageFormat) - CountMinSketch.readFrom(in) + CountMinSketch.readFrom(storageFormat) } override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg = @@ -132,8 +124,7 @@ case class CountMinSketchAgg( copy(inputAggBufferOffset = newInputAggBufferOffset) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType), - DoubleType, DoubleType, IntegerType) + Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType) } override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 8456e244609bc988e49f1801883f82d9f0f923f2..fcb370ae8460f3fb2a103f8e135676b61b9f5719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -86,7 +86,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { (headBufferSize + bufferSize) * 2 } - val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { count => + Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count => val buffer = new PercentileDigest(relativeError) // Worst case, data is linear sorted (0 until count).foreach(buffer.add(_)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala index 6e08e29c0449f6280971e5e3fc83cf6935966766..10479630f3f9950e0e49e7a3ed623374f7314b26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala @@ -17,199 +17,114 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import java.io.ByteArrayInputStream -import java.nio.charset.StandardCharsets +import java.{lang => jl} -import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal} -import org.apache.spark.sql.types.{DecimalType, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch +/** + * Unit test suite for the count-min sketch SQL aggregate function [[CountMinSketchAgg]]. 
+ */ class CountMinSketchAggSuite extends SparkFunSuite { private val childExpression = BoundReference(0, IntegerType, nullable = true) private val epsOfTotalCount = 0.0001 private val confidence = 0.99 private val seed = 42 - - test("serialize and de-serialize") { - // Check empty serialize and de-serialize - val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence), - Literal(seed)) - val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed) - assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) - - // Check non-empty serialize and de-serialize - val random = new Random(31) - (0 until 10000).map(_ => random.nextInt(100)).foreach { value => - buffer.add(value) - } - assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) + private val rand = new Random(seed) + + /** Creates a count-min sketch aggregate expression, using the child expression defined above. */ + private def cms(eps: jl.Double, confidence: jl.Double, seed: jl.Integer): CountMinSketchAgg = { + new CountMinSketchAgg( + child = childExpression, + epsExpression = Literal(eps, DoubleType), + confidenceExpression = Literal(confidence, DoubleType), + seedExpression = Literal(seed, IntegerType)) } - def testHighLevelInterface[T: ClassTag]( - dataType: DataType, - sampledItemIndices: Array[Int], - allItems: Array[T], - exactFreq: Map[Any, Long]): Any = { - test(s"high level interface, update, merge, eval... - $dataType") { + /** + * Creates a new test case that compares our aggregate function with a reference implementation + * (using the underlying [[CountMinSketch]]). + * + * This works by splitting the items into two separate groups, aggregating each group, merging + * the two partial results back (to emulate partial aggregation), and then comparing the result + * with that generated by [[CountMinSketch]] directly. This assumes insertion order does not + * impact the result in count-min sketch. + */ + private def testDataType[T](dataType: DataType, items: Seq[T]): Unit = { + test("test data type " + dataType) { val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true), Literal(epsOfTotalCount), Literal(confidence), Literal(seed)) assert(!agg.nullable) - val group1 = 0 until sampledItemIndices.length / 2 - val group1Buffer = agg.createAggregationBuffer() - group1.foreach { index => - val input = InternalRow(allItems(sampledItemIndices(index))) - agg.update(group1Buffer, input) + val (seq1, seq2) = items.splitAt(items.size / 2) + val buf1 = addToAggregateBuffer(agg, seq1) + val buf2 = addToAggregateBuffer(agg, seq2) + + val sketch = agg.createAggregationBuffer() + agg.merge(sketch, buf1) + agg.merge(sketch, buf2) + + // Validate frequency estimation against reference implementation. 
+ val referenceSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + items.foreach { item => + referenceSketch.add(item match { + case u: UTF8String => u.getBytes + case _ => item + }) } - val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length - val group2Buffer = agg.createAggregationBuffer() - group2.foreach { index => - val input = InternalRow(allItems(sampledItemIndices(index))) - agg.update(group2Buffer, input) + items.foreach { item => + withClue(s"For item $item") { + val itemToTest = item match { + case u: UTF8String => u.getBytes + case _ => item + } + assert(referenceSketch.estimateCount(itemToTest) == sketch.estimateCount(itemToTest)) + } } - - var mergeBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, group1Buffer) - agg.merge(mergeBuffer, group2Buffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) - - // Merge in a different order - mergeBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, group2Buffer) - agg.merge(mergeBuffer, group1Buffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) - - // Merge with an empty partition - val emptyBuffer = agg.createAggregationBuffer() - agg.merge(mergeBuffer, emptyBuffer) - checkResult(agg.eval(mergeBuffer), allItems, exactFreq) } - } - - def testLowLevelInterface[T: ClassTag]( - dataType: DataType, - sampledItemIndices: Array[Int], - allItems: Array[T], - exactFreq: Map[Any, Long]): Any = { - test(s"low level interface, update, merge, eval... - ${dataType.typeName}") { - val inputAggregationBufferOffset = 1 - val mutableAggregationBufferOffset = 2 - // Phase one, partial mode aggregation - val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true), - Literal(epsOfTotalCount), Literal(confidence), Literal(seed)) - .withNewInputAggBufferOffset(inputAggregationBufferOffset) - .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) - - val mutableAggBuffer = new GenericInternalRow( - new Array[Any](mutableAggregationBufferOffset + 1)) - agg.initialize(mutableAggBuffer) - - sampledItemIndices.foreach { i => - agg.update(mutableAggBuffer, InternalRow(allItems(i))) - } - agg.serializeAggregateBufferInPlace(mutableAggBuffer) - - // Serialize the aggregation buffer - val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) - val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) - - // Phase 2: final mode aggregation - // Re-initialize the aggregation buffer - agg.initialize(mutableAggBuffer) - agg.merge(mutableAggBuffer, inputAggBuffer) - checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq) + def addToAggregateBuffer[T](agg: CountMinSketchAgg, items: Seq[T]): CountMinSketch = { + val buf = agg.createAggregationBuffer() + items.foreach { item => agg.update(buf, InternalRow(item)) } + buf } } - private def checkResult[T: ClassTag]( - result: Any, - data: Array[T], - exactFreq: Map[Any, Long]): Unit = { - result match { - case bytesData: Array[Byte] => - val in = new ByteArrayInputStream(bytesData) - val cms = CountMinSketch.readFrom(in) - val probCorrect = { - val numErrors = data.map { i => - val count = exactFreq.getOrElse(getProbeItem(i), 0L) - val item = i match { - case dec: Decimal => dec.toJavaBigDecimal - case str: UTF8String => str.getBytes - case _ => i - } - val ratio = (cms.estimateCount(item) - count).toDouble / data.length - if (ratio > epsOfTotalCount) 1 else 0 - }.sum + testDataType[Byte](ByteType, Seq.fill(100) { rand.nextInt(10).toByte }) - 1D - numErrors.toDouble / data.length - 
} + testDataType[Short](ShortType, Seq.fill(100) { rand.nextInt(10).toShort }) - assert( - probCorrect > confidence, - s"Confidence not reached: required $confidence, reached $probCorrect" - ) - case _ => fail("unexpected return type") - } - } + testDataType[Int](IntegerType, Seq.fill(100) { rand.nextInt(10) }) - private def getProbeItem[T: ClassTag](item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } + testDataType[Long](LongType, Seq.fill(100) { rand.nextInt(10) }) - def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = { - // Uses fixed seed to ensure reproducible test execution - val r = new Random(31) + testDataType[UTF8String](StringType, Seq.fill(100) { UTF8String.fromString(rand.nextString(1)) }) - val numAllItems = 1000000 - val allItems = Array.fill(numAllItems)(itemGenerator(r)) + testDataType[Array[Byte]](BinaryType, Seq.fill(100) { rand.nextString(1).getBytes() }) - val numSamples = numAllItems / 10 - val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) + test("serialize and de-serialize") { + // Check empty serialize and de-serialize + val agg = cms(epsOfTotalCount, confidence, seed) + val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed) + assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) - val exactFreq = { - val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) + // Check non-empty serialize and de-serialize + val random = new Random(31) + for (i <- 0 until 10) { + buffer.add(random.nextInt(100)) } - - testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq) - testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq) - } - - testItemType[Byte](ByteType) { _.nextInt().toByte } - - testItemType[Short](ShortType) { _.nextInt().toShort } - - testItemType[Int](IntegerType) { _.nextInt() } - - testItemType[Long](LongType) { _.nextLong() } - - testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) } - - testItemType[Float](FloatType) { _.nextFloat() } - - testItemType[Double](DoubleType) { _.nextDouble() } - - testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) } - - testItemType[Boolean](BooleanType) { _.nextBoolean() } - - testItemType[Array[Byte]](BinaryType) { r => - r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8) + assert(buffer.equals(agg.deserialize(agg.serialize(buffer)))) } - - test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") { + test("fails analysis if eps, confidence or seed provided is not foldable") { val wrongEps = new CountMinSketchAgg( childExpression, epsExpression = AttributeReference("a", DoubleType)(), @@ -227,88 +142,55 @@ class CountMinSketchAggSuite extends SparkFunSuite { seedExpression = AttributeReference("c", IntegerType)()) Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg => - assertEqual( - wrongAgg.checkInputDataTypes(), - TypeCheckFailure( - "The eps, confidence or seed provided must be a literal or constant foldable") - ) + assertResult( + TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) { + wrongAgg.checkInputDataTypes() + } } } test("fails analysis if parameters are invalid") { // parameters are null - val wrongEps = new CountMinSketchAgg( - childExpression, - 
epsExpression = Cast(Literal(null), DoubleType), - confidenceExpression = Literal(confidence), - seedExpression = Literal(seed)) - val wrongConfidence = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Cast(Literal(null), DoubleType), - seedExpression = Literal(seed)) - val wrongSeed = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Literal(confidence), - seedExpression = Cast(Literal(null), IntegerType)) + val wrongEps = cms(null, confidence, seed) + val wrongConfidence = cms(epsOfTotalCount, null, seed) + val wrongSeed = cms(epsOfTotalCount, confidence, null) Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg => - assertEqual( - wrongAgg.checkInputDataTypes(), - TypeCheckFailure("The eps, confidence or seed provided should not be null") - ) + assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) { + wrongAgg.checkInputDataTypes() + } } // parameters are out of the valid range Seq(0.0, -1000.0).foreach { invalidEps => - val invalidAgg = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(invalidEps), - confidenceExpression = Literal(confidence), - seedExpression = Literal(seed)) - assertEqual( - invalidAgg.checkInputDataTypes(), - TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)") - ) + val invalidAgg = cms(invalidEps, confidence, seed) + assertResult( + TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) { + invalidAgg.checkInputDataTypes() + } } Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence => - val invalidAgg = new CountMinSketchAgg( - childExpression, - epsExpression = Literal(epsOfTotalCount), - confidenceExpression = Literal(invalidConfidence), - seedExpression = Literal(seed)) - assertEqual( - invalidAgg.checkInputDataTypes(), - TypeCheckFailure( - s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)") - ) + val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed) + assertResult(TypeCheckFailure( + s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) { + invalidAgg.checkInputDataTypes() + } } } - private def assertEqual[T](left: T, right: T): Unit = { - assert(left == right) - } - test("null handling") { def isEqual(result: Any, other: CountMinSketch): Boolean = { - result match { - case bytesData: Array[Byte] => - val in = new ByteArrayInputStream(bytesData) - val cms = CountMinSketch.readFrom(in) - cms.equals(other) - case _ => fail("unexpected return type") - } + other.equals(CountMinSketch.readFrom(result.asInstanceOf[Array[Byte]])) } - val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence), - Literal(seed)) + val agg = cms(epsOfTotalCount, confidence, seed) val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed) val buffer = new GenericInternalRow(new Array[Any](1)) agg.initialize(buffer) // Empty aggregation buffer assert(isEqual(agg.eval(buffer), emptyCms)) + // Empty input row agg.update(buffer, InternalRow(null)) assert(isEqual(agg.eval(buffer), emptyCms)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index e98092df49518b691bdd2eeb993c6ad75e04d14a..62a75343a094608637771bf1bdf3cd2d328c5b07 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -21,6 +21,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest import org.apache.spark.sql.test.SharedSQLContext +/** + * End-to-end tests for approximate percentile aggregate function. + */ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala index 3e715a393e530f63c686344bbd87201fb221dde5..dea0d4c0c6d405224ef58556b4515ccc7f856ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala @@ -17,175 +17,29 @@ package org.apache.spark.sql -import java.io.ByteArrayInputStream -import java.nio.charset.StandardCharsets -import java.sql.{Date, Timestamp} - -import scala.reflect.ClassTag -import scala.util.Random - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{Decimal, StringType, _} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.sketch.CountMinSketch +/** + * End-to-end test suite for count_min_sketch. + */ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext { - private val table = "count_min_sketch_table" - - /** Uses fixed seed to ensure reproducible test execution */ - private val r = new Random(42) - private val numAllItems = 1000 - private val numSamples = numAllItems / 10 - - private val eps = 0.1D - private val confidence = 0.95D - private val seed = 11 - - val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01")) - val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01")) - val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00")) - val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00")) - - test(s"compute count-min sketch for multiple columns of different types") { - val (allBytes, sampledByteIndices, exactByteFreq) = - generateTestData[Byte] { _.nextInt().toByte } - val (allShorts, sampledShortIndices, exactShortFreq) = - generateTestData[Short] { _.nextInt().toShort } - val (allInts, sampledIntIndices, exactIntFreq) = - generateTestData[Int] { _.nextInt() } - val (allLongs, sampledLongIndices, exactLongFreq) = - generateTestData[Long] { _.nextLong() } - val (allStrings, sampledStringIndices, exactStringFreq) = - generateTestData[String] { r => r.nextString(r.nextInt(20)) } - val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r => - DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate) - } - val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r => - DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS) - } - val (allFloats, sampledFloatIndices, exactFloatFreq) = - generateTestData[Float] { _.nextFloat() } - val (allDoubles, sampledDoubleIndices, exactDoubleFreq) = - generateTestData[Double] { _.nextDouble() } - val (allDeciamls, sampledDecimalIndices, exactDecimalFreq) = - 
generateTestData[Decimal] { r => Decimal(r.nextDouble()) } - val (allBooleans, sampledBooleanIndices, exactBooleanFreq) = - generateTestData[Boolean] { _.nextBoolean() } - val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r => - r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8) - } - - val data = (0 until numSamples).map { i => - Row(allBytes(sampledByteIndices(i)), - allShorts(sampledShortIndices(i)), - allInts(sampledIntIndices(i)), - allLongs(sampledLongIndices(i)), - allStrings(sampledStringIndices(i)), - allDates(sampledDateIndices(i)), - allTimestamps(sampledTSIndices(i)), - allFloats(sampledFloatIndices(i)), - allDoubles(sampledDoubleIndices(i)), - allDeciamls(sampledDecimalIndices(i)), - allBooleans(sampledBooleanIndices(i)), - allBinaries(sampledBinaryIndices(i))) - } + test("count-min sketch") { + import testImplicits._ - val schema = StructType(Seq( - StructField("c1", ByteType), - StructField("c2", ShortType), - StructField("c3", IntegerType), - StructField("c4", LongType), - StructField("c5", StringType), - StructField("c6", DateType), - StructField("c7", TimestampType), - StructField("c8", FloatType), - StructField("c9", DoubleType), - StructField("c10", new DecimalType()), - StructField("c11", BooleanType), - StructField("c12", BinaryType))) + val eps = 0.1 + val confidence = 0.95 + val seed = 11 - withTempView(table) { - val rdd: RDD[Row] = spark.sparkContext.parallelize(data) - spark.createDataFrame(rdd, schema).createOrReplaceTempView(table) + val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5) + val sketch = CountMinSketch.readFrom(items.toDF("id") + .selectExpr(s"count_min_sketch(id, ${eps}d, ${confidence}d, $seed)") + .head().get(0).asInstanceOf[Array[Byte]]) - val cmsSql = schema.fieldNames.map { col => - s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)" - } - val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head() - schema.indices.foreach { i => - val binaryData = result.getAs[Array[Byte]](i) - val in = new ByteArrayInputStream(binaryData) - val cms = CountMinSketch.readFrom(in) - schema.fields(i).dataType match { - case ByteType => checkResult(cms, allBytes, exactByteFreq) - case ShortType => checkResult(cms, allShorts, exactShortFreq) - case IntegerType => checkResult(cms, allInts, exactIntFreq) - case LongType => checkResult(cms, allLongs, exactLongFreq) - case StringType => checkResult(cms, allStrings, exactStringFreq) - case DateType => - checkResult(cms, - allDates.map(DateTimeUtils.fromJavaDate), - exactDateFreq.map { e => - (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2) - }) - case TimestampType => - checkResult(cms, - allTimestamps.map(DateTimeUtils.fromJavaTimestamp), - exactTSFreq.map { e => - (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2) - }) - case FloatType => checkResult(cms, allFloats, exactFloatFreq) - case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq) - case DecimalType() => checkResult(cms, allDeciamls, exactDecimalFreq) - case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq) - case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq) - } - } - } - } - - private def checkResult[T: ClassTag]( - cms: CountMinSketch, - data: Array[T], - exactFreq: Map[Any, Long]): Unit = { - val probCorrect = { - val numErrors = data.map { i => - val count = exactFreq.getOrElse(getProbeItem(i), 0L) - val item = i match { - case dec: Decimal => dec.toJavaBigDecimal - case str: UTF8String => str.getBytes - case _ => i - } - val 
ratio = (cms.estimateCount(item) - count).toDouble / data.length - if (ratio > eps) 1 else 0 - }.sum - - 1D - numErrors.toDouble / data.length - } - - assert( - probCorrect > confidence, - s"Confidence not reached: required $confidence, reached $probCorrect" - ) - } - - private def getProbeItem[T: ClassTag](item: T): Any = item match { - // Use a string to represent the content of an array of bytes - case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8) - case i => identity(i) - } + val reference = CountMinSketch.create(eps, confidence, seed) + items.foreach(reference.add) - private def generateTestData[T: ClassTag]( - itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = { - val allItems = Array.fill(numAllItems)(itemGenerator(r)) - val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems)) - val exactFreq = { - val sampledItems = sampledItemIndices.map(allItems) - sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong) - } - (allItems, sampledItemIndices, exactFreq) + assert(sketch == reference) } }
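Putting the pieces together, the rewritten query suite above boils down to the following end-to-end flow. This is a minimal sketch assuming an active SparkSession in scope as `spark`; the column name `id` and the eps/confidence/seed values mirror the test itself:

    import org.apache.spark.util.sketch.CountMinSketch
    import spark.implicits._

    val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5)

    // count_min_sketch(col, eps, confidence, seed) evaluates to the serialized
    // sketch as a binary column; the "d" suffix marks SQL double literals.
    val bytes = items.toDF("id")
      .selectExpr("count_min_sketch(id, 0.1d, 0.95d, 11)")
      .head().getAs[Array[Byte]](0)

    val sketch = CountMinSketch.readFrom(bytes)
    // A count-min sketch may over-count but never under-counts:
    // the exact frequency of 2 in the input is 4.
    assert(sketch.estimateCount(2) >= 4L)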