Skip to content
Snippets Groups Projects
Commit 833c8d41 authored by Kazuaki Ishizaki's avatar Kazuaki Ishizaki Committed by Wenchen Fan
Browse files

[SPARK-20770][SQL] Improve ColumnStats

## What changes were proposed in this pull request?

This PR improves the implementation of `ColumnStats` by using the following appoaches.

1. Declare subclasses of `ColumnStats` as `final`
2. Remove unnecessary call of `row.isNullAt(ordinal)`
3. Remove the dependency on `GenericInternalRow`

For 1., this declaration encourages method inlining and other optimizations of JIT compiler
For 2., in `gatherStats()`, while previous code in subclasses of `ColumnStats` always calls `row.isNullAt()` twice, the PR just calls `row.isNullAt()` only once.
For 3., `collectedStatistics()` returns `Array[Any]` instead of `GenericInternalRow`. This removes the dependency of unnecessary package and reduces the number of allocations of `GenericInternalRow`.

In addition to that, in the future, `gatherValueStats()`, which is specialized for each data type, can be effectively called from the generated code without using generic data structure `InternalRow`.

## How was this patch tested?

Tested by existing test suite

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #18002 from kiszk/SPARK-20770.
parent 3c9eef35
No related branches found
No related tags found
No related merge requests found
......@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.columnar
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
......@@ -53,219 +53,288 @@ private[columnar] sealed trait ColumnStats extends Serializable {
/**
* Gathers statistics information from `row(ordinal)`.
*/
def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (row.isNullAt(ordinal)) {
nullCount += 1
// 4 bytes for null position
sizeInBytes += 4
}
def gatherStats(row: InternalRow, ordinal: Int): Unit
/**
* Gathers statistics information on `null`.
*/
def gatherNullStats(): Unit = {
nullCount += 1
// 4 bytes for null position
sizeInBytes += 4
count += 1
}
/**
* Column statistics represented as a single row, currently including closed lower bound, closed
* Column statistics represented as an array, currently including closed lower bound, closed
* upper bound and null count.
*/
def collectedStatistics: GenericInternalRow
def collectedStatistics: Array[Any]
}
/**
* A no-op ColumnStats only used for testing purposes.
*/
private[columnar] class NoopColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
private[columnar] final class NoopColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
count += 1
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
override def collectedStatistics: Array[Any] = Array[Any](null, null, nullCount, count, 0L)
}
private[columnar] class BooleanColumnStats extends ColumnStats {
private[columnar] final class BooleanColumnStats extends ColumnStats {
protected var upper = false
protected var lower = true
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getBoolean(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += BOOLEAN.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Boolean): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += BOOLEAN.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class ByteColumnStats extends ColumnStats {
private[columnar] final class ByteColumnStats extends ColumnStats {
protected var upper = Byte.MinValue
protected var lower = Byte.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += BYTE.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Byte): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += BYTE.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class ShortColumnStats extends ColumnStats {
private[columnar] final class ShortColumnStats extends ColumnStats {
protected var upper = Short.MinValue
protected var lower = Short.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += SHORT.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Short): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += SHORT.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class IntColumnStats extends ColumnStats {
private[columnar] final class IntColumnStats extends ColumnStats {
protected var upper = Int.MinValue
protected var lower = Int.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += INT.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Int): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += INT.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class LongColumnStats extends ColumnStats {
private[columnar] final class LongColumnStats extends ColumnStats {
protected var upper = Long.MinValue
protected var lower = Long.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += LONG.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Long): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += LONG.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class FloatColumnStats extends ColumnStats {
private[columnar] final class FloatColumnStats extends ColumnStats {
protected var upper = Float.MinValue
protected var lower = Float.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += FLOAT.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Float): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += FLOAT.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class DoubleColumnStats extends ColumnStats {
private[columnar] final class DoubleColumnStats extends ColumnStats {
protected var upper = Double.MinValue
protected var lower = Double.MaxValue
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += DOUBLE.defaultSize
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Double): Unit = {
if (value > upper) upper = value
if (value < lower) lower = value
sizeInBytes += DOUBLE.defaultSize
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class StringColumnStats extends ColumnStats {
private[columnar] final class StringColumnStats extends ColumnStats {
protected var upper: UTF8String = null
protected var lower: UTF8String = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getUTF8String(ordinal)
if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
sizeInBytes += STRING.actualSize(row, ordinal)
val size = STRING.actualSize(row, ordinal)
gatherValueStats(value, size)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: UTF8String, size: Int): Unit = {
if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
sizeInBytes += size
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class BinaryColumnStats extends ColumnStats {
private[columnar] final class BinaryColumnStats extends ColumnStats {
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
sizeInBytes += BINARY.actualSize(row, ordinal)
val size = BINARY.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}
private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
def this(dt: DecimalType) = this(dt.precision, dt.scale)
protected var upper: Decimal = null
protected var lower: Decimal = null
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
val value = row.getDecimal(ordinal, precision, scale)
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
// TODO: this is not right for DecimalType with precision > 18
sizeInBytes += 8
gatherValueStats(value)
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
def gatherValueStats(value: Decimal): Unit = {
if (upper == null || value.compareTo(upper) > 0) upper = value
if (lower == null || value.compareTo(lower) < 0) lower = value
sizeInBytes += 8
count += 1
}
override def collectedStatistics: Array[Any] =
Array[Any](lower, upper, nullCount, count, sizeInBytes)
}
private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
val columnType = ColumnType(dataType)
override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
super.gatherStats(row, ordinal)
if (!row.isNullAt(ordinal)) {
sizeInBytes += columnType.actualSize(row, ordinal)
val size = columnType.actualSize(row, ordinal)
sizeInBytes += size
count += 1
} else {
gatherNullStats
}
}
override def collectedStatistics: GenericInternalRow =
new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
override def collectedStatistics: Array[Any] =
Array[Any](null, null, nullCount, count, sizeInBytes)
}
......@@ -123,8 +123,8 @@ case class InMemoryRelation(
batchStats.add(totalSize)
val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
.flatMap(_.values))
val stats = InternalRow.fromSeq(
columnBuilders.flatMap(_.columnStats.collectedStatistics))
CachedBatch(rowCount, columnBuilders.map { builder =>
JavaUtils.bufferToArray(builder.build())
}, stats)
......
......@@ -18,33 +18,29 @@
package org.apache.spark.sql.execution.columnar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
class ColumnStatsSuite extends SparkFunSuite {
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0))
testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0))
testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0))
testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0))
testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE,
createRow(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0))
testDecimalColumnStats(createRow(null, null, 0))
def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray)
testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
testDecimalColumnStats(Array(null, null, 0))
def testColumnStats[T <: AtomicType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
initialStatistics: GenericInternalRow): Unit = {
initialStatistics: Array[Any]): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
columnStats.collectedStatistics.zip(initialStatistics).foreach {
case (actual, expected) => assert(actual === expected)
}
}
......@@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
assertResult(10, "Wrong null count")(stats.values(2))
assertResult(20, "Wrong row count")(stats.values(3))
assertResult(stats.values(4), "Wrong size in bytes") {
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
assertResult(10, "Wrong null count")(stats(2))
assertResult(20, "Wrong row count")(stats(3))
assertResult(stats(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
......@@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite {
}
def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](
initialStatistics: GenericInternalRow): Unit = {
initialStatistics: Array[Any]): Unit = {
val columnStatsName = classOf[DecimalColumnStats].getSimpleName
val columnType = COMPACT_DECIMAL(15, 10)
test(s"$columnStatsName: empty") {
val columnStats = new DecimalColumnStats(15, 10)
columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach {
columnStats.collectedStatistics.zip(initialStatistics).foreach {
case (actual, expected) => assert(actual === expected)
}
}
......@@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite {
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]]
val stats = columnStats.collectedStatistics
assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1))
assertResult(10, "Wrong null count")(stats.values(2))
assertResult(20, "Wrong row count")(stats.values(3))
assertResult(stats.values(4), "Wrong size in bytes") {
assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
assertResult(10, "Wrong null count")(stats(2))
assertResult(20, "Wrong row count")(stats(3))
assertResult(stats(4), "Wrong size in bytes") {
rows.map { row =>
if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
}.sum
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment