diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 79865609cb6474dfb63a800e373066efaa5a78e9..465fbab5716ac8f94d31a9f80b39822c93b10293 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -194,11 +194,12 @@ object ColumnStat extends Logging {
     val numNonNulls = if (col.nullable) Count(col) else Count(one)
     val ndv = Least(Seq(HyperLogLogPlusPlus(col, relativeSD), numNonNulls))
     val numNulls = Subtract(Count(one), numNonNulls)
+    val defaultSize = Literal(col.dataType.defaultSize, LongType)
 
     def fixedLenTypeStruct(castType: DataType) = {
       // For fixed width types, avg size should be the same as max size.
-      val avgSize = Literal(col.dataType.defaultSize, LongType)
-      struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, avgSize, avgSize)
+      struct(ndv, Cast(Min(col), castType), Cast(Max(col), castType), numNulls, defaultSize,
+        defaultSize)
     }
 
     col.dataType match {
@@ -213,7 +214,9 @@ object ColumnStat extends Logging {
         val nullLit = Literal(null, col.dataType)
         struct(
           ndv, nullLit, nullLit, numNulls,
-          Ceil(Average(Length(col))), Cast(Max(Length(col)), LongType))
+          // Set avg/max size to default size if all the values are null or there is no value.
+          Coalesce(Seq(Ceil(Average(Length(col))), defaultSize)),
+          Coalesce(Seq(Cast(Max(Length(col)), LongType), defaultSize)))
       case _ =>
         throw new AnalysisException("Analyzing column statistics is not supported for column " +
           s"${col.name} of data type: ${col.dataType}.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 1fcccd061079e2791fb428a9f4c7456a25a2d6d0..07408491953caa1d78cbaf1707b2dd3469c54afb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -21,6 +21,7 @@ import java.{lang => jl}
 import java.sql.{Date, Timestamp}
 
 import scala.collection.mutable
+import scala.util.Random
 
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -133,6 +134,40 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
     }
   }
 
+  test("column stats round trip serialization") {
+    // Make sure we serialize and then deserialize and we will get the result data
+    val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+    stats.zip(df.schema).foreach { case ((k, v), field) =>
+      withClue(s"column $k with type ${field.dataType}") {
+        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
+        assert(roundtrip == Some(v))
+      }
+    }
+  }
+
+  test("analyze column command - result verification") {
+    // (data.head.productArity - 1) because the last column does not support stats collection.
+    assert(stats.size == data.head.productArity - 1)
+    val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+    checkColStats(df, stats)
+  }
+
+  test("column stats collection for null columns") {
+    val dataTypes: Seq[(DataType, Int)] = Seq(
+      BooleanType, ByteType, ShortType, IntegerType, LongType,
+      DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT,
+      StringType, BinaryType, DateType, TimestampType
+    ).zipWithIndex
+
+    val df = sql("select " + dataTypes.map { case (tpe, idx) =>
+      s"cast(null as ${tpe.sql}) as col$idx"
+    }.mkString(", "))
+
+    val expectedColStats = dataTypes.map { case (tpe, idx) =>
+      (s"col$idx", ColumnStat(0, None, None, 1, tpe.defaultSize.toLong, tpe.defaultSize.toLong))
+    }
+    checkColStats(df, mutable.LinkedHashMap(expectedColStats: _*))
+  }
 }
 
 
@@ -141,7 +176,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
  * when using the Hive external catalog) as well as in the sql/core module.
  */
 abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils {
-  import testImplicits._
 
   private val dec1 = new java.math.BigDecimal("1.000000000000000000")
   private val dec2 = new java.math.BigDecimal("8.000000000000000000")
@@ -180,35 +214,28 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
     "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8)
   )
 
-  test("column stats round trip serialization") {
-    // Make sure we serialize and then deserialize and we will get the result data
-    val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
-    stats.zip(df.schema).foreach { case ((k, v), field) =>
-      withClue(s"column $k with type ${field.dataType}") {
-        val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap)
-        assert(roundtrip == Some(v))
-      }
-    }
-  }
-
-  test("analyze column command - result verification") {
-    val tableName = "column_stats_test2"
-    // (data.head.productArity - 1) because the last column does not support stats collection.
-    assert(stats.size == data.head.productArity - 1)
-    val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)
+  private val randomName = new Random(31)
+  /**
+   * Compute column stats for the given DataFrame and compare it with colStats.
+   */
+  def checkColStats(
+      df: DataFrame,
+      colStats: mutable.LinkedHashMap[String, ColumnStat]): Unit = {
+    val tableName = "column_stats_test_" + randomName.nextInt(1000)
     withTable(tableName) {
       df.write.saveAsTable(tableName)
 
       // Collect statistics
-      sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", "))
+      sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " +
+        colStats.keys.mkString(", "))
 
       // Validate statistics
       val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
       assert(table.stats.isDefined)
-      assert(table.stats.get.colStats.size == stats.size)
+      assert(table.stats.get.colStats.size == colStats.size)
 
-      stats.foreach { case (k, v) =>
+      colStats.foreach { case (k, v) =>
         withClue(s"column $k") {
           assert(table.stats.get.colStats(k) == v)
         }
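
For reference, the behavior exercised by the new "column stats collection for null columns" test can be reproduced outside the suite roughly as below. This is a minimal sketch and not part of the patch: the object name NullColumnStatsDemo and the table name null_stats_demo are made up, and it assumes the code is compiled inside the org.apache.spark.sql package (as the test suites above are) so that the private[sql] spark.sessionState is accessible.

// Sketch only: runs ANALYZE on a table with a single all-null string column and reads
// back the collected ColumnStat, the same flow the null-column test drives.
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.TableIdentifier

object NullColumnStatsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("null-column-stats").getOrCreate()

    // A single row whose only column is a NULL string, saved as a managed table.
    spark.sql("select cast(null as string) as s").write.saveAsTable("null_stats_demo")

    // Same command the suite issues to collect per-column statistics.
    spark.sql("analyze table null_stats_demo compute STATISTICS FOR COLUMNS s")

    // For an all-null column, Average(Length(col)) and Max(Length(col)) evaluate to null;
    // the Coalesce fallback in the patch substitutes the column type's default size, so the
    // stored avg/max length come back as StringType.defaultSize rather than null.
    val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("null_stats_demo"))
    println(table.stats.get.colStats("s"))

    spark.stop()
  }
}

The fallback itself simply reuses the type's defaultSize, which the fixed-width path already uses for avg/max size, so string and binary columns degrade to the same estimate as fixed-width columns when every value is null or the table is empty.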