diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f24b240956a61114506d420841c3257b901eeaa8..3d4efef953a64e1c9bbc7d9d794e081c85fa2e58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -74,11 +75,10 @@ case class Statistics(
  * Statistics collected for a column.
  *
  * 1. Supported data types are defined in `ColumnStat.supportsType`.
- * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
- * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
- * TimestampType we store java.sql.Timestamp.
- * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
- * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * 2. The JVM data type stored in min/max is the internal data type for the corresponding
+ *    Catalyst data type. For example, the internal type of DateType is Int, and that the internal
+ *    type of TimestampType is Long.
+ * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
  *    (sketches) might have been used, and the data collected can also be stale.
  *
  * @param distinctCount number of distinct values
@@ -104,22 +104,43 @@ case class ColumnStat(
   /**
    * Returns a map from string to string that can be used to serialize the column stats.
    * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
-   * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+   * representation for the value. min/max values are converted to the external data type. For
+   * example, for DateType we store java.sql.Date, and for TimestampType we store
+   * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
    *
    * As part of the protocol, the returned map always contains a key called "version".
    * In the case min/max values are null (None), they won't appear in the map.
    */
-  def toMap: Map[String, String] = {
+  def toMap(colName: String, dataType: DataType): Map[String, String] = {
     val map = new scala.collection.mutable.HashMap[String, String]
     map.put(ColumnStat.KEY_VERSION, "1")
     map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
     map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
     map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
     map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
-    min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
-    max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+    min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
+    max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
     map.toMap
   }
+
+  /**
+   * Converts the given value from Catalyst data type to string representation of external
+   * data type.
+   */
+  private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
+    val externalValue = dataType match {
+      case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
+      case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
+      case BooleanType | _: IntegralType | FloatType | DoubleType => v
+      case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
+      // This version of Spark does not use min/max for binary/string types so we ignore it.
+      case _ =>
+        throw new AnalysisException("Column statistics deserialization is not supported for " +
+          s"column $colName of data type: $dataType.")
+    }
+    externalValue.toString
+  }
+
 }
 
 
@@ -150,28 +171,15 @@ object ColumnStat extends Logging {
    * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
    * from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
    */
-  def fromMap(table: String, field: StructField, map: Map[String, String])
-    : Option[ColumnStat] = {
-    val str2val: (String => Any) = field.dataType match {
-      case _: IntegralType => _.toLong
-      case _: DecimalType => new java.math.BigDecimal(_)
-      case DoubleType | FloatType => _.toDouble
-      case BooleanType => _.toBoolean
-      case DateType => java.sql.Date.valueOf
-      case TimestampType => java.sql.Timestamp.valueOf
-      // This version of Spark does not use min/max for binary/string types so we ignore it.
-      case BinaryType | StringType => _ => null
-      case _ =>
-        throw new AnalysisException("Column statistics deserialization is not supported for " +
-          s"column ${field.name} of data type: ${field.dataType}.")
-    }
-
+  def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
     try {
       Some(ColumnStat(
         distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
         // Note that flatMap(Option.apply) turns Option(null) into None.
-        min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
-        max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+        min = map.get(KEY_MIN_VALUE)
+          .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
+        max = map.get(KEY_MAX_VALUE)
+          .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
         nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
         avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
         maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
@@ -183,6 +191,30 @@ object ColumnStat extends Logging {
     }
   }
 
+  /**
+   * Converts from string representation of external data type to the corresponding Catalyst data
+   * type.
+   */
+  private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
+    dataType match {
+      case BooleanType => s.toBoolean
+      case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
+      case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
+      case ByteType => s.toByte
+      case ShortType => s.toShort
+      case IntegerType => s.toInt
+      case LongType => s.toLong
+      case FloatType => s.toFloat
+      case DoubleType => s.toDouble
+      case _: DecimalType => Decimal(s)
+      // This version of Spark does not use min/max for binary/string types so we ignore it.
+      case BinaryType | StringType => null
+      case _ =>
+        throw new AnalysisException("Column statistics deserialization is not supported for " +
+          s"column $name of data type: $dataType.")
+    }
+  }
+
   /**
    * Constructs an expression to compute column statistics for a given column.
    *
@@ -232,11 +264,14 @@ object ColumnStat extends Logging {
   }
 
   /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
-  def rowToColumnStat(row: Row): ColumnStat = {
+  def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
     ColumnStat(
       distinctCount = BigInt(row.getLong(0)),
-      min = Option(row.get(1)),  // for string/binary min/max, get should return null
-      max = Option(row.get(2)),
+      // for string/binary min/max, get should return null
+      min = Option(row.get(1))
+        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
+      max = Option(row.get(2))
+        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
       nullCount = BigInt(row.getLong(3)),
       avgLen = row.getLong(4),
       maxLen = row.getLong(5)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 5577233ffa6fe0ff68a882125f3869559d87bf2e..f1aff62cb6af0d3f7c25e260fcc626816d2bb628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{DecimalType, _}
 
 
 object EstimationUtils {
@@ -75,4 +75,32 @@ object EstimationUtils {
     // (simple computation of statistics returns product of children).
     if (outputRowCount > 0) outputRowCount * sizePerRow else 1
   }
+
+  /**
+   * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+   * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
+   * The two methods below are the contract of conversion.
+   */
+  def toDecimal(value: Any, dataType: DataType): Decimal = {
+    dataType match {
+      case _: NumericType | DateType | TimestampType => Decimal(value.toString)
+      case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+    }
+  }
+
+  def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+    dataType match {
+      case BooleanType => dec.toLong == 1
+      case DateType => dec.toInt
+      case TimestampType => dec.toLong
+      case ByteType => dec.toByte
+      case ShortType => dec.toShort
+      case IntegerType => dec.toInt
+      case LongType => dec.toLong
+      case FloatType => dec.toFloat
+      case DoubleType => dec.toDouble
+      case _: DecimalType => dec
+    }
+  }
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 7bd8e6511232f38dd03bef5a01df2088c60dac70..4b6b3b14d9ac8c37ea0d3ffbf514447715ab167f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -301,30 +300,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     }
   }
 
-  /**
-   * For a SQL data type, its internal data type may be different from its external type.
-   * For DateType, its internal type is Int, and its external data type is Java Date type.
-   * The min/max values in ColumnStat are saved in their corresponding external type.
-   *
-   * @param attrDataType the column data type
-   * @param litValue the literal value
-   * @return a BigDecimal value
-   */
-  def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
-    attrDataType match {
-      case DateType =>
-        Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
-      case TimestampType =>
-        Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
-      case _: DecimalType =>
-        Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
-      case StringType | BinaryType =>
-        None
-      case _ =>
-        Some(litValue)
-    }
-  }
-
   /**
    * Returns a percentage of rows meeting an equality (=) expression.
    * This method evaluates the equality predicate for all data types.
@@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     val statsRange = Range(colStat.min, colStat.max, attr.dataType)
     if (statsRange.contains(literal)) {
       if (update) {
-        // We update ColumnStat structure after apply this equality predicate.
-        // Set distinctCount to 1.  Set nullCount to 0.
-        // Need to save new min/max using the external type value of the literal
-        val newValue = convertBoundValue(attr.dataType, literal.value)
-        val newStats = colStat.copy(distinctCount = 1, min = newValue,
-          max = newValue, nullCount = 0)
+        // We update ColumnStat structure after apply this equality predicate:
+        // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
+        // value.
+        val newStats = attr.dataType match {
+          case StringType | BinaryType =>
+            colStat.copy(distinctCount = 1, nullCount = 0)
+          case _ =>
+            colStat.copy(distinctCount = 1, min = Some(literal.value),
+              max = Some(literal.value), nullCount = 0)
+        }
         colStatsMap(attr) = newStats
       }
 
@@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
           return Some(0.0)
         }
 
-        // Need to save new min/max using the external type value of the literal
-        val newMax = convertBoundValue(
-          attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
-        val newMin = convertBoundValue(
-          attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
-
+        val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
+        val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
         // newNdv should not be greater than the old ndv.  For example, column has only 2 values
         // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
         newNdv = ndv.min(BigInt(validQuerySet.size))
         if (update) {
-          val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
-                max = newMax, nullCount = 0)
+          val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
+            max = Some(newMax), nullCount = 0)
           colStatsMap(attr) = newStats
         }
 
@@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
 
     val colStat = colStatsMap(attr)
     val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
-    val max = BigDecimal(statsRange.max)
-    val min = BigDecimal(statsRange.min)
+    val max = statsRange.max.toBigDecimal
+    val min = statsRange.min.toBigDecimal
     val ndv = BigDecimal(colStat.distinctCount)
 
     // determine the overlapping degree between predicate range and column's range
@@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
       }
 
       if (update) {
-        // Need to save new min/max using the external type value of the literal
-        val newValue = convertBoundValue(attr.dataType, literal.value)
+        val newValue = Some(literal.value)
         var newMax = colStat.max
         var newMin = colStat.min
         var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
@@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     val colStatLeft = colStatsMap(attrLeft)
     val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
       .asInstanceOf[NumericRange]
-    val maxLeft = BigDecimal(statsRangeLeft.max)
-    val minLeft = BigDecimal(statsRangeLeft.min)
+    val maxLeft = statsRangeLeft.max
+    val minLeft = statsRangeLeft.min
 
     val colStatRight = colStatsMap(attrRight)
     val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
       .asInstanceOf[NumericRange]
-    val maxRight = BigDecimal(statsRangeRight.max)
-    val minRight = BigDecimal(statsRangeRight.min)
+    val maxRight = statsRangeRight.max
+    val minRight = statsRangeRight.min
 
     // determine the overlapping degree between predicate range and column's range
     val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 3d13967cb62a4358ebc2974e894d657752783b56..4ac5ba5689f82de2b775f6b93c51c93e2b954933 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -17,12 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import java.math.{BigDecimal => JDecimal}
-import java.sql.{Date, Timestamp}
-
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+import org.apache.spark.sql.types._
 
 
 /** Value range of a column. */
@@ -31,13 +27,10 @@ trait Range {
 }
 
 /** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+case class NumericRange(min: Decimal, max: Decimal) extends Range {
   override def contains(l: Literal): Boolean = {
-    val decimal = l.dataType match {
-      case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
-      case _ => new JDecimal(l.value.toString)
-    }
-    min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+    val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+    min <= lit && max >= lit
   }
 }
 
@@ -58,7 +51,10 @@ object Range {
   def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
     case StringType | BinaryType => new DefaultRange()
     case _ if min.isEmpty || max.isEmpty => new NullRange()
-    case _ => toNumericRange(min.get, max.get, dataType)
+    case _ =>
+      NumericRange(
+        min = EstimationUtils.toDecimal(min.get, dataType),
+        max = EstimationUtils.toDecimal(max.get, dataType))
   }
 
   def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -82,51 +78,11 @@ object Range {
         // binary/string types don't support intersecting.
         (None, None)
       case (n1: NumericRange, n2: NumericRange) =>
-        val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
-        val (newMin, newMax) = fromNumericRange(newRange, dt)
-        (Some(newMin), Some(newMax))
+        // Choose the maximum of two min values, and the minimum of two max values.
+        val newMin = if (n1.min <= n2.min) n2.min else n1.min
+        val newMax = if (n1.max <= n2.max) n1.max else n2.max
+        (Some(EstimationUtils.fromDecimal(newMin, dt)),
+          Some(EstimationUtils.fromDecimal(newMax, dt)))
     }
   }
-
-  /**
-   * For simplicity we use decimal to unify operations of numeric types, the two methods below
-   * are the contract of conversion.
-   */
-  private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
-    dataType match {
-      case _: NumericType =>
-        NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
-      case BooleanType =>
-        val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
-        val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-      case DateType =>
-        val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
-        val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-      case TimestampType =>
-        val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
-        val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-    }
-  }
-
-  private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
-    dataType match {
-      case _: IntegralType =>
-        (n.min.longValue(), n.max.longValue())
-      case FloatType | DoubleType =>
-        (n.min.doubleValue(), n.max.doubleValue())
-      case _: DecimalType =>
-        (n.min, n.max)
-      case BooleanType =>
-        (n.min.longValue() == 1, n.max.longValue() == 1)
-      case DateType =>
-        (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
-      case TimestampType =>
-        (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
-          DateTimeUtils.toJavaTimestamp(n.max.longValue()))
-    }
-  }
-
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index cffb0d87392874a151e69215e84c29299036f138..a28447840ae097a081d4e26293ced451e3172432 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.plans.LeftOuter
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     nullCount = 0, avgLen = 1, maxLen = 1)
 
   // column cdate has 10 values from 2017-01-01 through 2017-01-10.
-  val dMin = Date.valueOf("2017-01-01")
-  val dMax = Date.valueOf("2017-01-10")
+  val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
+  val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
   val attrDate = AttributeReference("cdate", DateType)()
   val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
   // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
-  val decMin = new java.math.BigDecimal("0.200000000000000000")
-  val decMax = new java.math.BigDecimal("0.800000000000000000")
+  val decMin = Decimal("0.200000000000000000")
+  val decMax = Decimal("0.800000000000000000")
   val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
   val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
     nullCount = 0, avgLen = 8, maxLen = 8)
@@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint < 3 OR null") {
     val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
-    val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
     validateEstimatedStats(
       Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
       Seq(attrInt -> colStatInt),
@@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 7)
   }
 
+  test("cbool IN (true)") {
+    validateEstimatedStats(
+      Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+      Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+        nullCount = 0, avgLen = 1, maxLen = 1)),
+      expectedRowCount = 5)
+  }
+
   test("cbool = true") {
     validateEstimatedStats(
       Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
@@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdate = cast('2017-01-02' AS DATE)") {
-    val d20170102 = Date.valueOf("2017-01-02")
+    val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
     validateEstimatedStats(
-      Filter(EqualTo(attrDate, Literal(d20170102)),
+      Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
         childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
         nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdate < cast('2017-01-03' AS DATE)") {
-    val d20170103 = Date.valueOf("2017-01-03")
+    val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
     validateEstimatedStats(
-      Filter(LessThan(attrDate, Literal(d20170103)),
+      Filter(LessThan(attrDate, Literal(d20170103, DateType)),
         childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
         nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("""cdate IN ( cast('2017-01-03' AS DATE),
       cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
-    val d20170103 = Date.valueOf("2017-01-03")
-    val d20170104 = Date.valueOf("2017-01-04")
-    val d20170105 = Date.valueOf("2017-01-05")
+    val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
+    val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
+    val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
     validateEstimatedStats(
-      Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
-        childStatsTestPlan(Seq(attrDate), 10L)),
+      Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
+        Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
         nullCount = 0, avgLen = 4, maxLen = 4)),
       expectedRowCount = 3)
   }
 
   test("cdecimal = 0.400000000000000000") {
-    val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+    val dec_0_40 = Decimal("0.400000000000000000")
     validateEstimatedStats(
       Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
         childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdecimal < 0.60 ") {
-    val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+    val dec_0_60 = Decimal("0.600000000000000000")
     validateEstimatedStats(
       Filter(LessThan(attrDecimal, Literal(dec_0_60)),
         childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint = cint3") {
     // no records qualify due to no overlap
-    val emptyColStats = Seq[(Attribute, ColumnStat)]()
     validateEstimatedStats(
       Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
       Nil, // set to empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index f62df842fa50ad7fb5111d0a2cd5639fd5f6cf8f..2d6b6e8e21f34f6f17b73b4cf24542110b26b8e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap,
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{DateType, TimestampType, _}
 
 
@@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
   test("test join keys of different types") {
     /** Columns in a table with only one row */
     def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
-      val dec = new java.math.BigDecimal("1.000000000000000000")
-      val date = Date.valueOf("2016-05-08")
-      val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
+      val dec = Decimal("1.000000000000000000")
+      val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+      val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
       mutable.LinkedHashMap[Attribute, ColumnStat](
         AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
           min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
         AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
+          min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
         AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
+          min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
         AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
+          min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
         AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
           min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
         AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
           min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
         AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
+          min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
         AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
           min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
         AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index f408dc4153586409035fe095073a42bae696b651..a5c4d22a29386a27d0a97d672c5e2dc8c1b7e190 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 
@@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("test row size estimation") {
-    val dec1 = new java.math.BigDecimal("1.000000000000000000")
-    val dec2 = new java.math.BigDecimal("8.000000000000000000")
-    val d1 = Date.valueOf("2016-05-08")
-    val d2 = Date.valueOf("2016-05-09")
-    val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
-    val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+    val dec1 = Decimal("1.000000000000000000")
+    val dec2 = Decimal("8.000000000000000000")
+    val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+    val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09"))
+    val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
+    val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
 
     val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
       AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
         min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
       AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+        min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
       AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+        min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
       AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+        min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
       AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
         min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
       AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
         min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
       AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+        min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
       AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
         min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
       AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index b89014ed8ef54f3af68289c3dec1ec95d029cf24..0d8db2ff5d5a03a931d70beb0de34789032b1502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -73,10 +73,10 @@ case class AnalyzeColumnCommand(
     val relation = sparkSession.table(tableIdent).logicalPlan
     // Resolve the column names and dedup using AttributeSet
     val resolver = sparkSession.sessionState.conf.resolver
-    val attributesToAnalyze = AttributeSet(columnNames.map { col =>
+    val attributesToAnalyze = columnNames.map { col =>
       val exprOption = relation.output.find(attr => resolver(attr.name, col))
       exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
-    }).toSeq
+    }
 
     // Make sure the column types are supported for stats gathering.
     attributesToAnalyze.foreach { attr =>
@@ -99,8 +99,8 @@ case class AnalyzeColumnCommand(
     val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
 
     val rowCount = statsRow.getLong(0)
-    val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
-      (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
+    val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
+      (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
     }.toMap
     (rowCount, columnStats)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 1f547c5a2a8ff628bb251b2bb70c463b40e92423..ddc393c8da053060082e21cc6f385872d53fd577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
     val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
     stats.zip(df.schema).foreach { case ((k, v), field) =>
       withClue(s"column $k with type ${field.dataType}") {
-        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType))
         assert(roundtrip == Some(v))
       }
     }
@@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
   /** A mapping from column to the stats collected. */
   protected val stats = mutable.LinkedHashMap(
     "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
-    "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
-    "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
-    "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+    "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1),
+    "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2),
+    "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4),
     "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
     "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
-    "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
-    "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+    "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4),
+    "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16),
     "cstring" -> ColumnStat(2, None, None, 1, 3, 3),
     "cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
-    "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
-    "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+    "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)),
+      Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4),
+    "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)),
+      Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8)
   )
 
   private val randomName = new Random(31)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 806f2be5faeb06c3f459cc2bc8926a74d8bddfff..8b0fdf49cefabb453283da2b0efe38af85802120 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       if (stats.rowCount.isDefined) {
         statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
       }
+      val colNameTypeMap: Map[String, DataType] =
+        tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap
       stats.colStats.foreach { case (colName, colStat) =>
-        colStat.toMap.foreach { case (k, v) =>
+        colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
           statsProperties += (columnStatKeyPropName(colName, k) -> v)
         }
       }