diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 5692e574d4c7e77b1490299f2287c2762493ec58..f0aac5bb00dfb32a479fe0ea1e1e129c986eab1a 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -115,15 +115,45 @@ abstract public class CountMinSketch {
   public abstract long totalCount();
 
   /**
-   * Adds 1 to {@code item}.
+   * Increments {@code item}'s count by one.
    */
   public abstract void add(Object item);
 
   /**
-   * Adds {@code count} to {@code item}.
+   * Increments {@code item}'s count by {@code count}.
    */
   public abstract void add(Object item, long count);
 
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addLong(long item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addLong(long item, long count);
+
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addString(String item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addString(String item, long count);
+
+  /**
+   * Increments {@code item}'s count by one.
+   */
+  public abstract void addBinary(byte[] item);
+
+  /**
+   * Increments {@code item}'s count by {@code count}.
+   */
+  public abstract void addBinary(byte[] item, long count);
+
   /**
    * Returns the estimated frequency of {@code item}.
    */
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index e49ae22906c4c828ead21d6c8a525b9671d20126..c0631c6778df48a6f8df7683a009358b46e45136 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -25,7 +25,6 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
 import java.io.Serializable;
-import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -146,27 +145,49 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     }
   }
 
-  private void addString(String item, long count) {
+  @Override
+  public void addString(String item) {
+    addString(item, 1);
+  }
+
+  @Override
+  public void addString(String item, long count) {
+    addBinary(Utils.getBytesFromUTF8String(item), count);
+  }
+
+  @Override
+  public void addLong(long item) {
+    addLong(item, 1);
+  }
+
+  @Override
+  public void addLong(long item, long count) {
     if (count < 0) {
       throw new IllegalArgumentException("Negative increments not implemented");
     }
 
-    int[] buckets = getHashBuckets(item, depth, width);
-
     for (int i = 0; i < depth; ++i) {
-      table[i][buckets[i]] += count;
+      table[i][hash(item, i)] += count;
     }
 
     totalCount += count;
   }
 
-  private void addLong(long item, long count) {
+  @Override
+  public void addBinary(byte[] item) {
+    addBinary(item, 1);
+  }
+
+  @Override
+  public void addBinary(byte[] item, long count) {
     if (count < 0) {
       throw new IllegalArgumentException("Negative increments not implemented");
     }
 
+    int[] buckets = getHashBuckets(item, depth, width);
+
     for (int i = 0; i < depth; ++i) {
-      table[i][hash(item, i)] += count;
+      table[i][buckets[i]] += count;
     }
 
     totalCount += count;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index b0b6995a2214f0558be91c0583bba02721c80dea..bb3cc02800d5113a8962ef33a07bd5ea0d477d72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.stat._
-import org.apache.spark.sql.types.{IntegralType, StringType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
 
 /**
@@ -109,7 +109,6 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * Null elements will be replaced by "null", and back ticks will be dropped from elements if they
    * exist.
    *
-   *
    * @param col1 The name of the first column. Distinct items will make the first item of
    *             each row.
    * @param col2 The name of the second column. Distinct items will make the column names
@@ -374,21 +373,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
     val singleCol = df.select(col)
     val colType = singleCol.schema.head.dataType
 
-    require(
-      colType == StringType || colType.isInstanceOf[IntegralType],
-      s"Count-min Sketch only supports string type and integral types, " +
-        s"and does not support type $colType."
-    )
+    val updater: (CountMinSketch, InternalRow) => Unit = colType match {
+      // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
+      // instead of `addString` to avoid unnecessary conversion.
+      case StringType => (sketch, row) => sketch.addBinary(row.getUTF8String(0).getBytes)
+      case ByteType => (sketch, row) => sketch.addLong(row.getByte(0))
+      case ShortType => (sketch, row) => sketch.addLong(row.getShort(0))
+      case IntegerType => (sketch, row) => sketch.addLong(row.getInt(0))
+      case LongType => (sketch, row) => sketch.addLong(row.getLong(0))
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Count-min Sketch only supports string type and integral types, " +
+            s"and does not support type $colType."
+        )
+    }
 
-    singleCol.rdd.aggregate(zero)(
-      (sketch: CountMinSketch, row: Row) => {
-        sketch.add(row.get(0))
+    singleCol.queryExecution.toRdd.aggregate(zero)(
+      (sketch: CountMinSketch, row: InternalRow) => {
+        updater(sketch, row)
         sketch
       },
-
-      (sketch1: CountMinSketch, sketch2: CountMinSketch) => {
-        sketch1.mergeInPlace(sketch2)
-      }
+      (sketch1, sketch2) => sketch1.mergeInPlace(sketch2)
     )
   }
 
@@ -447,19 +452,27 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
     require(colType == StringType || colType.isInstanceOf[IntegralType],
       s"Bloom filter only supports string type and integral types, but got $colType.")
 
-    val seqOp: (BloomFilter, InternalRow) => BloomFilter = if (colType == StringType) {
-      (filter, row) =>
-        // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
-        // instead of `putString` to avoid unnecessary conversion.
-        filter.putBinary(row.getUTF8String(0).getBytes)
-        filter
-    } else {
-      (filter, row) =>
-        // TODO: specialize it.
-        filter.putLong(row.get(0, colType).asInstanceOf[Number].longValue())
-        filter
+    val updater: (BloomFilter, InternalRow) => Unit = colType match {
+      // For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
+      // instead of `putString` to avoid unnecessary conversion.
+      case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
+      case ByteType => (filter, row) => filter.putLong(row.getByte(0))
+      case ShortType => (filter, row) => filter.putLong(row.getShort(0))
+      case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
+      case LongType => (filter, row) => filter.putLong(row.getLong(0))
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Bloom filter only supports string type and integral types, " +
+            s"and does not support type $colType."
+        )
     }
 
-    singleCol.queryExecution.toRdd.aggregate(zero)(seqOp, _ mergeInPlace _)
+    singleCol.queryExecution.toRdd.aggregate(zero)(
+      (filter: BloomFilter, row: InternalRow) => {
+        updater(filter, row)
+        filter
+      },
+      (filter1, filter2) => filter1.mergeInPlace(filter2)
+    )
   }
 }