diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index f24b240956a61114506d420841c3257b901eeaa8..3d4efef953a64e1c9bbc7d9d794e081c85fa2e58 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -74,11 +75,10 @@ case class Statistics(
  * Statistics collected for a column.
  *
  * 1. Supported data types are defined in `ColumnStat.supportsType`.
- * 2. The JVM data type stored in min/max is the external data type (used in Row) for the
- *    corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for
- *    TimestampType we store java.sql.Timestamp.
- * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs.
- * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms
+ * 2. The JVM data type stored in min/max is the internal data type for the corresponding
+ *    Catalyst data type. For example, the internal type of DateType is Int, and the internal
+ *    type of TimestampType is Long.
+ * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms
  *    (sketches) might have been used, and the data collected can also be stale.
  *
  * @param distinctCount number of distinct values
@@ -104,22 +104,43 @@ case class ColumnStat(
   /**
    * Returns a map from string to string that can be used to serialize the column stats.
    * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string
-   * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]].
+   * representation for the value. min/max values are converted to the external data type. For
+   * example, for DateType we store java.sql.Date, and for TimestampType we store
+   * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]].
    *
    * As part of the protocol, the returned map always contains a key called "version".
    * In the case min/max values are null (None), they won't appear in the map.
    */
-  def toMap: Map[String, String] = {
+  def toMap(colName: String, dataType: DataType): Map[String, String] = {
     val map = new scala.collection.mutable.HashMap[String, String]
     map.put(ColumnStat.KEY_VERSION, "1")
     map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString)
     map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString)
     map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString)
     map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString)
-    min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) }
-    max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) }
+    min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) }
+    max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) }
     map.toMap
   }
+
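The revised contract above means callers of ColumnStat now see internal values in min/max: a days-since-epoch Int for DateType and a microseconds-since-epoch Long for TimestampType, instead of java.sql.Date/java.sql.Timestamp. A minimal standalone sketch of that mapping, using java.time for illustration only (the actual Spark conversions go through DateTimeUtils and handle more corner cases):

```scala
import java.time.{Instant, LocalDate, ZoneOffset}

// Illustration only: how a DateType/TimestampType value looks internally vs. externally.
object InternalRepresentation {
  // DateType: internal Int = days since 1970-01-01
  def dateToInternal(d: LocalDate): Int = d.toEpochDay.toInt

  // TimestampType: internal Long = microseconds since the epoch
  def timestampToInternal(i: Instant): Long =
    i.getEpochSecond * 1000000L + i.getNano / 1000L

  def main(args: Array[String]): Unit = {
    println(dateToInternal(LocalDate.parse("2017-01-01")))            // 17167
    val startOfDay = LocalDate.parse("2017-01-01").atStartOfDay(ZoneOffset.UTC).toInstant
    println(timestampToInternal(startOfDay))                          // 1483228800000000
  }
}
```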
+  /**
+   * Converts the given value from Catalyst data type to string representation of external
+   * data type.
+   */
+  private def toExternalString(v: Any, colName: String, dataType: DataType): String = {
+    val externalValue = dataType match {
+      case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int])
+      case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long])
+      case BooleanType | _: IntegralType | FloatType | DoubleType => v
+      case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal
+      // This version of Spark does not use min/max for binary/string types so we ignore it.
+      case _ =>
+        throw new AnalysisException("Column statistics deserialization is not supported for " +
+          s"column $colName of data type: $dataType.")
+    }
+    externalValue.toString
+  }
+
 }
 
@@ -150,28 +171,15 @@ object ColumnStat extends Logging {
    * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats
    * from some external storage. The serialization side is defined in [[ColumnStat.toMap]].
    */
-  def fromMap(table: String, field: StructField, map: Map[String, String])
-    : Option[ColumnStat] = {
-    val str2val: (String => Any) = field.dataType match {
-      case _: IntegralType => _.toLong
-      case _: DecimalType => new java.math.BigDecimal(_)
-      case DoubleType | FloatType => _.toDouble
-      case BooleanType => _.toBoolean
-      case DateType => java.sql.Date.valueOf
-      case TimestampType => java.sql.Timestamp.valueOf
-      // This version of Spark does not use min/max for binary/string types so we ignore it.
-      case BinaryType | StringType => _ => null
-      case _ =>
-        throw new AnalysisException("Column statistics deserialization is not supported for " +
-          s"column ${field.name} of data type: ${field.dataType}.")
-    }
-
+  def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = {
     try {
       Some(ColumnStat(
         distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong),
         // Note that flatMap(Option.apply) turns Option(null) into None.
-        min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply),
-        max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply),
+        min = map.get(KEY_MIN_VALUE)
+          .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
+        max = map.get(KEY_MAX_VALUE)
+          .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply),
         nullCount = BigInt(map(KEY_NULL_COUNT).toLong),
         avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong,
         maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong
@@ -183,6 +191,30 @@ object ColumnStat extends Logging {
     }
   }
 
+  /**
+   * Converts from string representation of external data type to the corresponding Catalyst data
+   * type.
+   */
+  private def fromExternalString(s: String, name: String, dataType: DataType): Any = {
+    dataType match {
+      case BooleanType => s.toBoolean
+      case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s))
+      case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s))
+      case ByteType => s.toByte
+      case ShortType => s.toShort
+      case IntegerType => s.toInt
+      case LongType => s.toLong
+      case FloatType => s.toFloat
+      case DoubleType => s.toDouble
+      case _: DecimalType => Decimal(s)
+      // This version of Spark does not use min/max for binary/string types so we ignore it.
+      case BinaryType | StringType => null
+      case _ =>
+        throw new AnalysisException("Column statistics deserialization is not supported for " +
+          s"column $name of data type: $dataType.")
+    }
+  }
+
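For the catalog, toMap renders min/max in the external, human-readable form and fromMap parses that string back into the internal form. A standalone sketch of that round trip for a DateType column, with hypothetical helper names mirroring the private methods above (the real code delegates to DateTimeUtils and java.sql.Date):

```scala
import java.time.LocalDate

// Sketch of the serialize/deserialize contract for a DateType min/max value:
// internal Int (days since epoch) -> external "yyyy-MM-dd" string -> internal Int.
object DateStatRoundTrip {
  def toExternalString(days: Int): String = LocalDate.ofEpochDay(days.toLong).toString
  def fromExternalString(s: String): Int = LocalDate.parse(s).toEpochDay.toInt

  def main(args: Array[String]): Unit = {
    val internalMin = 17168                        // 2017-01-02
    val serialized = toExternalString(internalMin) // the value stored in table properties
    assert(fromExternalString(serialized) == internalMin)
    println(serialized)                            // 2017-01-02
  }
}
```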
   /**
    * Constructs an expression to compute column statistics for a given column.
    *
@@ -232,11 +264,14 @@ object ColumnStat extends Logging {
   }
 
   /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
-  def rowToColumnStat(row: Row): ColumnStat = {
+  def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
     ColumnStat(
       distinctCount = BigInt(row.getLong(0)),
-      min = Option(row.get(1)), // for string/binary min/max, get should return null
-      max = Option(row.get(2)),
+      // for string/binary min/max, get should return null
+      min = Option(row.get(1))
+        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
+      max = Option(row.get(2))
+        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
       nullCount = BigInt(row.getLong(3)),
       avgLen = row.getLong(4),
       maxLen = row.getLong(5)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 5577233ffa6fe0ff68a882125f3869559d87bf2e..f1aff62cb6af0d3f7c25e260fcc626816d2bb628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StringType}
+import org.apache.spark.sql.types.{DecimalType, _}
 
 object EstimationUtils {
 
@@ -75,4 +75,32 @@ object EstimationUtils {
     // (simple computation of statistics returns product of children).
     if (outputRowCount > 0) outputRowCount * sizePerRow else 1
   }
+
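The two conversion helpers added in the next hunk route every orderable min/max value through Decimal so that range arithmetic uses a single numeric representation: booleans become 0 or 1, dates their day count, timestamps their microsecond count. A rough standalone equivalent of that idea using scala.math.BigDecimal and a hypothetical helper name (not Spark's Decimal):

```scala
// Rough sketch of the "unify everything through one decimal type" idea.
object DecimalUnification {
  def toBig(value: Any): BigDecimal = value match {
    case b: Boolean => if (b) BigDecimal(1) else BigDecimal(0) // BooleanType
    case n: Byte    => BigDecimal(n.toInt)                     // integral types
    case n: Short   => BigDecimal(n.toInt)
    case n: Int     => BigDecimal(n)                           // also DateType's internal Int
    case n: Long    => BigDecimal(n)                           // also TimestampType's internal Long
    case n: Float   => BigDecimal(n.toDouble)
    case n: Double  => BigDecimal(n)
  }

  def main(args: Array[String]): Unit = {
    // Once everything is a BigDecimal, min/max comparisons are uniform across types.
    println(toBig(true) max toBig(false))   // 1
    println(toBig(17167) <= toBig(17168))   // true
  }
}
```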
+  /**
+   * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+   * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
+   * The two methods below are the contract of conversion.
+   */
+  def toDecimal(value: Any, dataType: DataType): Decimal = {
+    dataType match {
+      case _: NumericType | DateType | TimestampType => Decimal(value.toString)
+      case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+    }
+  }
+
+  def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+    dataType match {
+      case BooleanType => dec.toLong == 1
+      case DateType => dec.toInt
+      case TimestampType => dec.toLong
+      case ByteType => dec.toByte
+      case ShortType => dec.toShort
+      case IntegerType => dec.toInt
+      case LongType => dec.toLong
+      case FloatType => dec.toFloat
+      case DoubleType => dec.toDouble
+      case _: DecimalType => dec
+    }
+  }
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 7bd8e6511232f38dd03bef5a01df2088c60dac70..4b6b3b14d9ac8c37ea0d3ffbf514447715ab167f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -301,30 +300,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     }
   }
 
-  /**
-   * For a SQL data type, its internal data type may be different from its external type.
-   * For DateType, its internal type is Int, and its external data type is Java Date type.
-   * The min/max values in ColumnStat are saved in their corresponding external type.
-   *
-   * @param attrDataType the column data type
-   * @param litValue the literal value
-   * @return a BigDecimal value
-   */
-  def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = {
-    attrDataType match {
-      case DateType =>
-        Some(DateTimeUtils.toJavaDate(litValue.toString.toInt))
-      case TimestampType =>
-        Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong))
-      case _: DecimalType =>
-        Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal)
-      case StringType | BinaryType =>
-        None
-      case _ =>
-        Some(litValue)
-    }
-  }
-
   /**
    * Returns a percentage of rows meeting an equality (=) expression.
    * This method evaluates the equality predicate for all data types.
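With min/max already stored in internal form, the equality path in the next hunk no longer needs convertBoundValue: it only checks that the literal falls inside the column's range and, if it does, assumes a uniform distribution, i.e. a selectivity of 1/ndv. A toy sketch of that estimate with hypothetical numbers (not the Spark API):

```scala
// Toy model of equality selectivity: if the literal lies inside [min, max], assume the
// qualifying fraction of rows is 1 / distinctCount; otherwise nothing matches.
object EqualitySelectivity {
  def selectivity(lit: BigDecimal, min: BigDecimal, max: BigDecimal, ndv: Long): Double =
    if (lit >= min && lit <= max) 1.0 / ndv else 0.0

  def main(args: Array[String]): Unit = {
    // A date column with 10 distinct days, filtered by cdate = '2017-01-02' (internally 17168).
    println(selectivity(BigDecimal(17168), BigDecimal(17167), BigDecimal(17176), ndv = 10)) // 0.1
  }
}
```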
@@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     val statsRange = Range(colStat.min, colStat.max, attr.dataType)
     if (statsRange.contains(literal)) {
       if (update) {
-        // We update ColumnStat structure after apply this equality predicate.
-        // Set distinctCount to 1. Set nullCount to 0.
-        // Need to save new min/max using the external type value of the literal
-        val newValue = convertBoundValue(attr.dataType, literal.value)
-        val newStats = colStat.copy(distinctCount = 1, min = newValue,
-          max = newValue, nullCount = 0)
+        // We update the ColumnStat structure after applying this equality predicate:
+        // set distinctCount to 1, nullCount to 0, and min/max values (if they exist) to the
+        // literal value.
+        val newStats = attr.dataType match {
+          case StringType | BinaryType =>
+            colStat.copy(distinctCount = 1, nullCount = 0)
+          case _ =>
+            colStat.copy(distinctCount = 1, min = Some(literal.value),
+              max = Some(literal.value), nullCount = 0)
+        }
         colStatsMap(attr) = newStats
       }
 
@@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
       return Some(0.0)
     }
 
-    // Need to save new min/max using the external type value of the literal
-    val newMax = convertBoundValue(
-      attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString)))
-    val newMin = convertBoundValue(
-      attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString)))
-
+    val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
+    val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
     // newNdv should not be greater than the old ndv. For example, column has only 2 values
     // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
     newNdv = ndv.min(BigInt(validQuerySet.size))
     if (update) {
-      val newStats = colStat.copy(distinctCount = newNdv, min = newMin,
-        max = newMax, nullCount = 0)
+      val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin),
+        max = Some(newMax), nullCount = 0)
       colStatsMap(attr) = newStats
     }
 
@@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
 
     val colStat = colStatsMap(attr)
     val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
-    val max = BigDecimal(statsRange.max)
-    val min = BigDecimal(statsRange.min)
+    val max = statsRange.max.toBigDecimal
+    val min = statsRange.min.toBigDecimal
     val ndv = BigDecimal(colStat.distinctCount)
 
     // determine the overlapping degree between predicate range and column's range
@@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     }
 
     if (update) {
-      // Need to save new min/max using the external type value of the literal
-      val newValue = convertBoundValue(attr.dataType, literal.value)
+      val newValue = Some(literal.value)
       var newMax = colStat.max
       var newMin = colStat.min
       var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt()
@@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging
     val colStatLeft = colStatsMap(attrLeft)
     val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
       .asInstanceOf[NumericRange]
-    val maxLeft = BigDecimal(statsRangeLeft.max)
-    val minLeft = BigDecimal(statsRangeLeft.min)
+    val maxLeft = statsRangeLeft.max
+    val minLeft = statsRangeLeft.min
 
     val colStatRight = colStatsMap(attrRight)
     val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
       .asInstanceOf[NumericRange]
-    val maxRight = BigDecimal(statsRangeRight.max)
-    val minRight = BigDecimal(statsRangeRight.min)
+    val maxRight = statsRangeRight.max
+    val minRight = statsRangeRight.min
 
     // determine the overlapping degree between predicate range and column's range
     val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
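The comparison path above works on the position of the literal inside the [min, max] interval once both sides are Decimals; for example, `col < literal` is estimated as the fraction of the range that lies below the literal. A simplified sketch of that overlap computation under a uniform-distribution assumption (illustrative only; the real code also adjusts ndv, null count, and the updated bounds):

```scala
// Simplified range-overlap estimate for "col < literal", assuming values are uniformly
// distributed over [min, max]; all operands have already been unified into BigDecimal.
object LessThanSelectivity {
  def selectivity(lit: BigDecimal, min: BigDecimal, max: BigDecimal): Double = {
    if (lit <= min) 0.0
    else if (lit > max) 1.0
    else ((lit - min) / (max - min)).toDouble
  }

  def main(args: Array[String]): Unit = {
    // cdate < '2017-01-03' over a column spanning 2017-01-01..2017-01-10 (17167..17176).
    println(selectivity(BigDecimal(17169), BigDecimal(17167), BigDecimal(17176))) // ~0.22
  }
}
```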
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
index 3d13967cb62a4358ebc2974e894d657752783b56..4ac5ba5689f82de2b775f6b93c51c93e2b954933 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
@@ -17,12 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
-import java.math.{BigDecimal => JDecimal}
-import java.sql.{Date, Timestamp}
-
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _}
+import org.apache.spark.sql.types._
 
 
 /** Value range of a column. */
@@ -31,13 +27,10 @@ trait Range {
 }
 
 /** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: JDecimal, max: JDecimal) extends Range {
+case class NumericRange(min: Decimal, max: Decimal) extends Range {
   override def contains(l: Literal): Boolean = {
-    val decimal = l.dataType match {
-      case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0)
-      case _ => new JDecimal(l.value.toString)
-    }
-    min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0
+    val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+    min <= lit && max >= lit
   }
 }
 
@@ -58,7 +51,10 @@ object Range {
   def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
     case StringType | BinaryType => new DefaultRange()
     case _ if min.isEmpty || max.isEmpty => new NullRange()
-    case _ => toNumericRange(min.get, max.get, dataType)
+    case _ =>
+      NumericRange(
+        min = EstimationUtils.toDecimal(min.get, dataType),
+        max = EstimationUtils.toDecimal(max.get, dataType))
   }
 
   def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
@@ -82,51 +78,11 @@ object Range {
       // binary/string types don't support intersecting.
       (None, None)
     case (n1: NumericRange, n2: NumericRange) =>
-      val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max))
-      val (newMin, newMax) = fromNumericRange(newRange, dt)
-      (Some(newMin), Some(newMax))
+      // Choose the maximum of two min values, and the minimum of two max values.
+      val newMin = if (n1.min <= n2.min) n2.min else n1.min
+      val newMax = if (n1.max <= n2.max) n1.max else n2.max
+      (Some(EstimationUtils.fromDecimal(newMin, dt)),
+        Some(EstimationUtils.fromDecimal(newMax, dt)))
     }
   }
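Range intersection now stays in Decimal throughout: the new lower bound is the larger of the two minimums and the new upper bound the smaller of the two maximums, converted back to the column's type only at the end via fromDecimal. A small standalone sketch of that rule with plain BigDecimal intervals (hypothetical helper, not the Spark class):

```scala
// Intersecting two closed numeric intervals: take the max of the mins and the min of the maxes.
// A result where newMin > newMax would mean the two ranges do not overlap at all.
object RangeIntersection {
  def intersect(r1: (BigDecimal, BigDecimal), r2: (BigDecimal, BigDecimal)): (BigDecimal, BigDecimal) =
    (r1._1 max r2._1, r1._2 min r2._2)

  def main(args: Array[String]): Unit = {
    val joined = intersect((BigDecimal(1), BigDecimal(6)), (BigDecimal(4), BigDecimal(10)))
    println(joined)   // (4,6)
  }
}
```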
-
-  /**
-   * For simplicity we use decimal to unify operations of numeric types, the two methods below
-   * are the contract of conversion.
-   */
-  private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = {
-    dataType match {
-      case _: NumericType =>
-        NumericRange(new JDecimal(min.toString), new JDecimal(max.toString))
-      case BooleanType =>
-        val min1 = if (min.asInstanceOf[Boolean]) 1 else 0
-        val max1 = if (max.asInstanceOf[Boolean]) 1 else 0
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-      case DateType =>
-        val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date])
-        val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date])
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-      case TimestampType =>
-        val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp])
-        val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp])
-        NumericRange(new JDecimal(min1), new JDecimal(max1))
-    }
-  }
-
-  private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = {
-    dataType match {
-      case _: IntegralType =>
-        (n.min.longValue(), n.max.longValue())
-      case FloatType | DoubleType =>
-        (n.min.doubleValue(), n.max.doubleValue())
-      case _: DecimalType =>
-        (n.min, n.max)
-      case BooleanType =>
-        (n.min.longValue() == 1, n.max.longValue() == 1)
-      case DateType =>
-        (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue()))
-      case TimestampType =>
-        (DateTimeUtils.toJavaTimestamp(n.min.longValue()),
-          DateTimeUtils.toJavaTimestamp(n.max.longValue()))
-    }
-  }
-
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index cffb0d87392874a151e69215e84c29299036f138..a28447840ae097a081d4e26293ced451e3172432 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
 import org.apache.spark.sql.catalyst.plans.LeftOuter
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 /**
@@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     nullCount = 0, avgLen = 1, maxLen = 1)
 
   // column cdate has 10 values from 2017-01-01 through 2017-01-10.
-  val dMin = Date.valueOf("2017-01-01")
-  val dMax = Date.valueOf("2017-01-10")
+  val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01"))
+  val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10"))
   val attrDate = AttributeReference("cdate", DateType)()
   val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax),
     nullCount = 0, avgLen = 4, maxLen = 4)
 
   // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20.
-  val decMin = new java.math.BigDecimal("0.200000000000000000")
-  val decMax = new java.math.BigDecimal("0.800000000000000000")
+  val decMin = Decimal("0.200000000000000000")
+  val decMax = Decimal("0.800000000000000000")
   val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))()
   val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax),
     nullCount = 0, avgLen = 8, maxLen = 8)
 
@@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint < 3 OR null") {
     val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType))
-    val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf)
     validateEstimatedStats(
       Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
       Seq(attrInt -> colStatInt),
@@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 7)
   }
 
+  test("cbool IN (true)") {
+    validateEstimatedStats(
+      Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)),
+      Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true),
+        nullCount = 0, avgLen = 1, maxLen = 1)),
+      expectedRowCount = 5)
+  }
+
   test("cbool = true") {
     validateEstimatedStats(
       Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)),
@@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdate = cast('2017-01-02' AS DATE)") {
-    val d20170102 = Date.valueOf("2017-01-02")
+    val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02"))
     validateEstimatedStats(
-      Filter(EqualTo(attrDate, Literal(d20170102)),
+      Filter(EqualTo(attrDate, Literal(d20170102, DateType)),
         childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102),
         nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdate < cast('2017-01-03' AS DATE)") {
-    val d20170103 = Date.valueOf("2017-01-03")
+    val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
     validateEstimatedStats(
-      Filter(LessThan(attrDate, Literal(d20170103)),
+      Filter(LessThan(attrDate, Literal(d20170103, DateType)),
        childStatsTestPlan(Seq(attrDate), 10L)),
      Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103),
        nullCount = 0, avgLen = 4, maxLen = 4)),
@@ -379,19 +387,19 @@ test("""cdate IN ( cast('2017-01-03' AS DATE),
      cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") {
-    val d20170103 = Date.valueOf("2017-01-03")
-    val d20170104 = Date.valueOf("2017-01-04")
-    val d20170105 = Date.valueOf("2017-01-05")
+    val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
+    val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04"))
+    val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05"))
     validateEstimatedStats(
-      Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))),
-        childStatsTestPlan(Seq(attrDate), 10L)),
+      Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType),
+        Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)),
       Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105),
         nullCount = 0, avgLen = 4, maxLen = 4)),
       expectedRowCount = 3)
   }
 
   test("cdecimal = 0.400000000000000000") {
-    val dec_0_40 = new java.math.BigDecimal("0.400000000000000000")
+    val dec_0_40 = Decimal("0.400000000000000000")
     validateEstimatedStats(
       Filter(EqualTo(attrDecimal, Literal(dec_0_40)),
         childStatsTestPlan(Seq(attrDecimal), 4L)),
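Note that the date tests above have to spell out the data type when wrapping an internal value in a Literal: the stored Int for cdate would otherwise be inferred as IntegerType and never match the DateType column statistics. A hedged sketch of that pattern, assuming the 2.2-era catalyst API used by this suite:

```scala
import java.sql.Date

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.DateType

object DateLiteralSketch {
  def main(args: Array[String]): Unit = {
    // Internal representation of 2017-01-03: days since the epoch.
    val internal: Int = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03"))
    // Literal(internal) alone would be typed as IntegerType; passing DateType explicitly keeps
    // the predicate comparable against DateType column statistics.
    val dateLiteral = Literal(internal, DateType)
    println(dateLiteral.dataType)   // DateType
  }
}
```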
@@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("cdecimal < 0.60 ") {
-    val dec_0_60 = new java.math.BigDecimal("0.600000000000000000")
+    val dec_0_60 = Decimal("0.600000000000000000")
     validateEstimatedStats(
       Filter(LessThan(attrDecimal, Literal(dec_0_60)),
         childStatsTestPlan(Seq(attrDecimal), 4L)),
@@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
 
   test("cint = cint3") {
     // no records qualify due to no overlap
-    val emptyColStats = Seq[(Attribute, ColumnStat)]()
     validateEstimatedStats(
       Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)),
       Nil, // set to empty
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index f62df842fa50ad7fb5111d0a2cd5639fd5f6cf8f..2d6b6e8e21f34f6f17b73b4cf24542110b26b8e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap,
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{DateType, TimestampType, _}
 
 
@@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
   test("test join keys of different types") {
     /** Columns in a table with only one row */
     def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = {
-      val dec = new java.math.BigDecimal("1.000000000000000000")
-      val date = Date.valueOf("2016-05-08")
-      val timestamp = Timestamp.valueOf("2016-05-08 00:00:01")
+      val dec = Decimal("1.000000000000000000")
+      val date = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+      val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
       mutable.LinkedHashMap[Attribute, ColumnStat](
         AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1,
           min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1),
         AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1),
+          min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
         AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2),
+          min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
         AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4),
+          min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4),
         AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1,
           min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8),
         AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1,
          min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8),
         AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1,
-          min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4),
+          min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4),
         AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1,
           min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16),
         AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index f408dc4153586409035fe095073a42bae696b651..a5c4d22a29386a27d0a97d672c5e2dc8c1b7e190 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 
 
@@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
   }
 
   test("test row size estimation") {
-    val dec1 = new java.math.BigDecimal("1.000000000000000000")
-    val dec2 = new java.math.BigDecimal("8.000000000000000000")
-    val d1 = Date.valueOf("2016-05-08")
-    val d2 = Date.valueOf("2016-05-09")
-    val t1 = Timestamp.valueOf("2016-05-08 00:00:01")
-    val t2 = Timestamp.valueOf("2016-05-09 00:00:02")
+    val dec1 = Decimal("1.000000000000000000")
+    val dec2 = Decimal("8.000000000000000000")
+    val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08"))
+    val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09"))
+    val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01"))
+    val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02"))
 
     val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
       AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2,
         min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1),
       AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(2L), nullCount = 0, avgLen = 1, maxLen = 1),
+        min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1),
       AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2),
+        min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2),
       AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4),
+        min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4),
       AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2,
         min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8),
       AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2,
         min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8),
       AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2,
-        min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4),
+        min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4),
       AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2,
         min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16),
       AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index b89014ed8ef54f3af68289c3dec1ec95d029cf24..0d8db2ff5d5a03a931d70beb0de34789032b1502 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -73,10 +73,10 @@ case class AnalyzeColumnCommand(
     val relation = sparkSession.table(tableIdent).logicalPlan
     // Resolve the column names and dedup using AttributeSet
     val resolver = sparkSession.sessionState.conf.resolver
-    val attributesToAnalyze = AttributeSet(columnNames.map { col =>
+    val attributesToAnalyze = columnNames.map { col =>
       val exprOption = relation.output.find(attr => resolver(attr.name, col))
       exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist."))
-    }).toSeq
+    }
 
     // Make sure the column types are supported for stats gathering.
     attributesToAnalyze.foreach { attr =>
@@ -99,8 +99,8 @@ case class AnalyzeColumnCommand(
     val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
 
     val rowCount = statsRow.getLong(0)
-    val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) =>
-      (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1)))
+    val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
+      (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
     }.toMap
     (rowCount, columnStats)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 1f547c5a2a8ff628bb251b2bb70c463b40e92423..ddc393c8da053060082e21cc6f385872d53fd577 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -26,6 +26,7 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.internal.StaticSQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
@@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
     val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
     stats.zip(df.schema).foreach { case ((k, v), field) =>
       withClue(s"column $k with type ${field.dataType}") {
-        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType))
         assert(roundtrip == Some(v))
       }
     }
@@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
 
   /** A mapping from column to the stats collected. */
   protected val stats = mutable.LinkedHashMap(
     "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1),
-    "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1),
-    "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2),
-    "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4),
+    "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1),
+    "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2),
+    "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4),
     "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8),
     "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8),
-    "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4),
-    "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16),
+    "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4),
+    "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16),
     "cstring" -> ColumnStat(2, None, None, 1, 3, 3),
     "cbinary" -> ColumnStat(2, None, None, 1, 3, 3),
-    "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4),
-    "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
+    "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)),
+      Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4),
+    "ctimestamp" -> ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)),
+      Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8)
   )
 
   private val randomName = new Random(31)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 806f2be5faeb06c3f459cc2bc8926a74d8bddfff..8b0fdf49cefabb453283da2b0efe38af85802120 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
     if (stats.rowCount.isDefined) {
       statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
     }
+    val colNameTypeMap: Map[String, DataType] =
+      tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap
     stats.colStats.foreach { case (colName, colStat) =>
-      colStat.toMap.foreach { case (k, v) =>
+      colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
        statsProperties += (columnStatKeyPropName(colName, k) -> v)
       }
     }
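Because toMap now needs the column's data type to render min/max externally, the catalog first builds a name-to-type lookup from the table schema before serializing each ColumnStat, as in the last hunk. A small sketch of just that lookup step with a hypothetical schema standing in for tableDefinition.schema:

```scala
import org.apache.spark.sql.types._

object ColNameTypeMapSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical schema; in HiveExternalCatalog this comes from tableDefinition.schema.
    val schema = StructType(Seq(
      StructField("cdate", DateType),
      StructField("cdecimal", DecimalType(18, 18))))

    // Same shape as the map built before iterating over stats.colStats.
    val colNameTypeMap: Map[String, DataType] =
      schema.fields.map(f => (f.name, f.dataType)).toMap

    // The serializer can now be handed the right type per column, e.g. for "cdate".
    println(colNameTypeMap("cdate"))   // DateType
  }
}
```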