diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 470307bd940ad9f54aeb2fc937e120f827ab3953..bc7e73ae1ba8773d8b985413aa378adfe37cc1b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -53,219 +53,288 @@ private[columnar] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (row.isNullAt(ordinal)) { - nullCount += 1 - // 4 bytes for null position - sizeInBytes += 4 - } + def gatherStats(row: InternalRow, ordinal: Int): Unit + + /** + * Gathers statistics information on `null`. + */ + def gatherNullStats(): Unit = { + nullCount += 1 + // 4 bytes for null position + sizeInBytes += 4 count += 1 } /** - * Column statistics represented as a single row, currently including closed lower bound, closed + * Column statistics represented as an array, currently including closed lower bound, closed * upper bound and null count. */ - def collectedStatistics: GenericInternalRow + def collectedStatistics: Array[Any] } /** * A no-op ColumnStats only used for testing purposes. */ -private[columnar] class NoopColumnStats extends ColumnStats { - override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) +private[columnar] final class NoopColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + count += 1 + } else { + gatherNullStats + } + } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) + override def collectedStatistics: Array[Any] = Array[Any](null, null, nullCount, count, 0L) } -private[columnar] class BooleanColumnStats extends ColumnStats { +private[columnar] final class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getBoolean(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += BOOLEAN.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Boolean): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ByteColumnStats extends ColumnStats { +private[columnar] final class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += BYTE.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Byte): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BYTE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ShortColumnStats extends ColumnStats { +private[columnar] final class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += SHORT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Short): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += SHORT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class IntColumnStats extends ColumnStats { +private[columnar] final class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += INT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Int): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += INT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class LongColumnStats extends ColumnStats { +private[columnar] final class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Long): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += LONG.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class FloatColumnStats extends ColumnStats { +private[columnar] final class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += FLOAT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Float): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += FLOAT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class DoubleColumnStats extends ColumnStats { +private[columnar] final class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Double): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += DOUBLE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class StringColumnStats extends ColumnStats { +private[columnar] final class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getUTF8String(ordinal) - if (upper == null || value.compareTo(upper) > 0) upper = value.clone() - if (lower == null || value.compareTo(lower) < 0) lower = value.clone() - sizeInBytes += STRING.actualSize(row, ordinal) + val size = STRING.actualSize(row, ordinal) + gatherValueStats(value, size) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: UTF8String, size: Int): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value.clone() + if (lower == null || value.compareTo(lower) < 0) lower = value.clone() + sizeInBytes += size + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class BinaryColumnStats extends ColumnStats { +private[columnar] final class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += BINARY.actualSize(row, ordinal) + val size = BINARY.actualSize(row, ordinal) + sizeInBytes += size + count += 1 + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) } -private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDecimal(ordinal, precision, scale) - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value // TODO: this is not right for DecimalType with precision > 18 - sizeInBytes += 8 + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Decimal): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += 8 + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += columnType.actualSize(row, ordinal) + val size = columnType.actualSize(row, ordinal) + sizeInBytes += size + count += 1 + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 0a9f3e799990f16df9afca8bed60dc8749418c62..3486a6bce81800e92409ffadb55299f20bde690a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -123,8 +123,8 @@ case class InMemoryRelation( batchStats.add(totalSize) - val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) + val stats = InternalRow.fromSeq( + columnBuilders.flatMap(_.columnStats.collectedStatistics)) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index b2d04f7c5a6e36ee93b6766a02771dba745a72b9..d4e7e362c6c8c7e7d23016f6905f8b486974f4f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,33 +18,29 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0)) + testDecimalColumnStats(Array(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite { } def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = classOf[DecimalColumnStats].getSimpleName val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new DecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum