Skip to content
Snippets Groups Projects
Commit bf66335a authored by Wang Gengliang's avatar Wang Gengliang Committed by Reynold Xin
Browse files

[SPARK-21323][SQL] Rename plans.logical.statsEstimation.Range to ValueInterval

## What changes were proposed in this pull request?

Rename org.apache.spark.sql.catalyst.plans.logical.statsEstimation.Range to ValueInterval.
The current naming is identical to logical operator "range".
Refactoring it to ValueInterval is more accurate.

## How was this patch tested?

unit test

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #18549 from gengliangwang/ValueInterval.
parent 48e44b24
No related branches found
No related tags found
No related merge requests found
...@@ -316,8 +316,8 @@ case class FilterEstimation(plan: Filter) extends Logging { ...@@ -316,8 +316,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
// decide if the value is in [min, max] of the column. // decide if the value is in [min, max] of the column.
// We currently don't store min/max for binary/string type. // We currently don't store min/max for binary/string type.
// Hence, we assume it is in boundary for binary/string type. // Hence, we assume it is in boundary for binary/string type.
val statsRange = Range(colStat.min, colStat.max, attr.dataType) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType)
if (statsRange.contains(literal)) { if (statsInterval.contains(literal)) {
if (update) { if (update) {
// We update ColumnStat structure after apply this equality predicate: // We update ColumnStat structure after apply this equality predicate:
// Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
...@@ -388,9 +388,10 @@ case class FilterEstimation(plan: Filter) extends Logging { ...@@ -388,9 +388,10 @@ case class FilterEstimation(plan: Filter) extends Logging {
// use [min, max] to filter the original hSet // use [min, max] to filter the original hSet
dataType match { dataType match {
case _: NumericType | BooleanType | DateType | TimestampType => case _: NumericType | BooleanType | DateType | TimestampType =>
val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange] val statsInterval =
ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval]
val validQuerySet = hSet.filter { v => val validQuerySet = hSet.filter { v =>
v != null && statsRange.contains(Literal(v, dataType)) v != null && statsInterval.contains(Literal(v, dataType))
} }
if (validQuerySet.isEmpty) { if (validQuerySet.isEmpty) {
...@@ -440,12 +441,13 @@ case class FilterEstimation(plan: Filter) extends Logging { ...@@ -440,12 +441,13 @@ case class FilterEstimation(plan: Filter) extends Logging {
update: Boolean): Option[BigDecimal] = { update: Boolean): Option[BigDecimal] = {
val colStat = colStatsMap(attr) val colStat = colStatsMap(attr)
val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] val statsInterval =
val max = statsRange.max.toBigDecimal ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval]
val min = statsRange.min.toBigDecimal val max = statsInterval.max.toBigDecimal
val min = statsInterval.min.toBigDecimal
val ndv = BigDecimal(colStat.distinctCount) val ndv = BigDecimal(colStat.distinctCount)
// determine the overlapping degree between predicate range and column's range // determine the overlapping degree between predicate interval and column's interval
val numericLiteral = if (literal.dataType == BooleanType) { val numericLiteral = if (literal.dataType == BooleanType) {
if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
} else { } else {
...@@ -566,18 +568,18 @@ case class FilterEstimation(plan: Filter) extends Logging { ...@@ -566,18 +568,18 @@ case class FilterEstimation(plan: Filter) extends Logging {
} }
val colStatLeft = colStatsMap(attrLeft) val colStatLeft = colStatsMap(attrLeft)
val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) val statsIntervalLeft = ValueInterval(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
.asInstanceOf[NumericRange] .asInstanceOf[NumericValueInterval]
val maxLeft = statsRangeLeft.max val maxLeft = statsIntervalLeft.max
val minLeft = statsRangeLeft.min val minLeft = statsIntervalLeft.min
val colStatRight = colStatsMap(attrRight) val colStatRight = colStatsMap(attrRight)
val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) val statsIntervalRight = ValueInterval(colStatRight.min, colStatRight.max, attrRight.dataType)
.asInstanceOf[NumericRange] .asInstanceOf[NumericValueInterval]
val maxRight = statsRangeRight.max val maxRight = statsIntervalRight.max
val minRight = statsRangeRight.min val minRight = statsIntervalRight.min
// determine the overlapping degree between predicate range and column's range // determine the overlapping degree between predicate interval and column's interval
val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
val (noOverlap: Boolean, completeOverlap: Boolean) = op match { val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
// Left < Right or Left <= Right // Left < Right or Left <= Right
......
...@@ -175,9 +175,9 @@ case class InnerOuterEstimation(join: Join) extends Logging { ...@@ -175,9 +175,9 @@ case class InnerOuterEstimation(join: Join) extends Logging {
// Check if the two sides are disjoint // Check if the two sides are disjoint
val leftKeyStats = leftStats.attributeStats(leftKey) val leftKeyStats = leftStats.attributeStats(leftKey)
val rightKeyStats = rightStats.attributeStats(rightKey) val rightKeyStats = rightStats.attributeStats(rightKey)
val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
if (Range.isIntersected(lRange, rRange)) { if (ValueInterval.isIntersected(lInterval, rInterval)) {
// Get the largest ndv among pairs of join keys // Get the largest ndv among pairs of join keys
val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount) val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
if (maxNdv > ndvDenom) ndvDenom = maxNdv if (maxNdv > ndvDenom) ndvDenom = maxNdv
...@@ -239,16 +239,16 @@ case class InnerOuterEstimation(join: Join) extends Logging { ...@@ -239,16 +239,16 @@ case class InnerOuterEstimation(join: Join) extends Logging {
joinKeyPairs.foreach { case (leftKey, rightKey) => joinKeyPairs.foreach { case (leftKey, rightKey) =>
val leftKeyStats = leftStats.attributeStats(leftKey) val leftKeyStats = leftStats.attributeStats(leftKey)
val rightKeyStats = rightStats.attributeStats(rightKey) val rightKeyStats = rightStats.attributeStats(rightKey)
val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType) val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType) val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
// When we reach here, join selectivity is not zero, so each pair of join keys should be // When we reach here, join selectivity is not zero, so each pair of join keys should be
// intersected. // intersected.
assert(Range.isIntersected(lRange, rRange)) assert(ValueInterval.isIntersected(lInterval, rInterval))
// Update intersected column stats // Update intersected column stats
assert(leftKey.dataType.sameType(rightKey.dataType)) assert(leftKey.dataType.sameType(rightKey.dataType))
val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount) val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType) val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen) val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2 val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen) val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
......
...@@ -22,12 +22,12 @@ import org.apache.spark.sql.types._ ...@@ -22,12 +22,12 @@ import org.apache.spark.sql.types._
/** Value range of a column. */ /** Value range of a column. */
trait Range { trait ValueInterval {
def contains(l: Literal): Boolean def contains(l: Literal): Boolean
} }
/** For simplicity we use decimal to unify operations of numeric ranges. */ /** For simplicity we use decimal to unify operations of numeric intervals. */
case class NumericRange(min: Decimal, max: Decimal) extends Range { case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval {
override def contains(l: Literal): Boolean = { override def contains(l: Literal): Boolean = {
val lit = EstimationUtils.toDecimal(l.value, l.dataType) val lit = EstimationUtils.toDecimal(l.value, l.dataType)
min <= lit && max >= lit min <= lit && max >= lit
...@@ -38,46 +38,49 @@ case class NumericRange(min: Decimal, max: Decimal) extends Range { ...@@ -38,46 +38,49 @@ case class NumericRange(min: Decimal, max: Decimal) extends Range {
* This version of Spark does not have min/max for binary/string types, we define their default * This version of Spark does not have min/max for binary/string types, we define their default
* behaviors by this class. * behaviors by this class.
*/ */
class DefaultRange extends Range { class DefaultValueInterval extends ValueInterval {
override def contains(l: Literal): Boolean = true override def contains(l: Literal): Boolean = true
} }
/** This is for columns with only null values. */ /** This is for columns with only null values. */
class NullRange extends Range { class NullValueInterval extends ValueInterval {
override def contains(l: Literal): Boolean = false override def contains(l: Literal): Boolean = false
} }
object Range { object ValueInterval {
def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { def apply(
case StringType | BinaryType => new DefaultRange() min: Option[Any],
case _ if min.isEmpty || max.isEmpty => new NullRange() max: Option[Any],
dataType: DataType): ValueInterval = dataType match {
case StringType | BinaryType => new DefaultValueInterval()
case _ if min.isEmpty || max.isEmpty => new NullValueInterval()
case _ => case _ =>
NumericRange( NumericValueInterval(
min = EstimationUtils.toDecimal(min.get, dataType), min = EstimationUtils.toDecimal(min.get, dataType),
max = EstimationUtils.toDecimal(max.get, dataType)) max = EstimationUtils.toDecimal(max.get, dataType))
} }
def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match {
case (_, _: DefaultRange) | (_: DefaultRange, _) => case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
// The DefaultRange represents string/binary types which do not have max/min stats, // The DefaultValueInterval represents string/binary types which do not have max/min stats,
// we assume they are intersected to be conservative on estimation // we assume they are intersected to be conservative on estimation
true true
case (_, _: NullRange) | (_: NullRange, _) => case (_, _: NullValueInterval) | (_: NullValueInterval, _) =>
false false
case (n1: NumericRange, n2: NumericRange) => case (n1: NumericValueInterval, n2: NumericValueInterval) =>
n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0 n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
} }
/** /**
* Intersected results of two ranges. This is only for two overlapped ranges. * Intersected results of two intervals. This is only for two overlapped intervals.
* The outputs are the intersected min/max values. * The outputs are the intersected min/max values.
*/ */
def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = { def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = {
(r1, r2) match { (r1, r2) match {
case (_, _: DefaultRange) | (_: DefaultRange, _) => case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
// binary/string types don't support intersecting. // binary/string types don't support intersecting.
(None, None) (None, None)
case (n1: NumericRange, n2: NumericRange) => case (n1: NumericValueInterval, n2: NumericValueInterval) =>
// Choose the maximum of two min values, and the minimum of two max values. // Choose the maximum of two min values, and the minimum of two max values.
val newMin = if (n1.min <= n2.min) n2.min else n1.min val newMin = if (n1.min <= n2.min) n2.min else n1.min
val newMax = if (n1.max <= n2.max) n1.max else n2.max val newMax = if (n1.max <= n2.max) n1.max else n2.max
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment