diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 5a3bee7b9e4493df56ef357f1829587ca76160d6..e13db85c7a76e4db837b2791ee4ffa07bb404d39 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -316,8 +316,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
     // decide if the value is in [min, max] of the column.
     // We currently don't store min/max for binary/string type.
     // Hence, we assume it is in boundary for binary/string type.
-    val statsRange = Range(colStat.min, colStat.max, attr.dataType)
-    if (statsRange.contains(literal)) {
+    val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType)
+    if (statsInterval.contains(literal)) {
       if (update) {
         // We update ColumnStat structure after apply this equality predicate:
         // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal
@@ -388,9 +388,10 @@ case class FilterEstimation(plan: Filter) extends Logging {
     // use [min, max] to filter the original hSet
     dataType match {
       case _: NumericType | BooleanType | DateType | TimestampType =>
-        val statsRange = Range(colStat.min, colStat.max, dataType).asInstanceOf[NumericRange]
+        val statsInterval =
+          ValueInterval(colStat.min, colStat.max, dataType).asInstanceOf[NumericValueInterval]
         val validQuerySet = hSet.filter { v =>
-          v != null && statsRange.contains(Literal(v, dataType))
+          v != null && statsInterval.contains(Literal(v, dataType))
         }
 
         if (validQuerySet.isEmpty) {
@@ -440,12 +441,13 @@ case class FilterEstimation(plan: Filter) extends Logging {
       update: Boolean): Option[BigDecimal] = {
 
     val colStat = colStatsMap(attr)
-    val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange]
-    val max = statsRange.max.toBigDecimal
-    val min = statsRange.min.toBigDecimal
+    val statsInterval =
+      ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval]
+    val max = statsInterval.max.toBigDecimal
+    val min = statsInterval.min.toBigDecimal
     val ndv = BigDecimal(colStat.distinctCount)
 
-    // determine the overlapping degree between predicate range and column's range
+    // determine the overlapping degree between predicate interval and column's interval
     val numericLiteral = if (literal.dataType == BooleanType) {
       if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
     } else {
@@ -566,18 +568,18 @@ case class FilterEstimation(plan: Filter) extends Logging {
     }
 
     val colStatLeft = colStatsMap(attrLeft)
-    val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
-      .asInstanceOf[NumericRange]
-    val maxLeft = statsRangeLeft.max
-    val minLeft = statsRangeLeft.min
+    val statsIntervalLeft = ValueInterval(colStatLeft.min, colStatLeft.max, attrLeft.dataType)
+      .asInstanceOf[NumericValueInterval]
+    val maxLeft = statsIntervalLeft.max
+    val minLeft = statsIntervalLeft.min
 
     val colStatRight = colStatsMap(attrRight)
-    val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType)
-      .asInstanceOf[NumericRange]
-    val maxRight = statsRangeRight.max
-    val minRight = statsRangeRight.min
+    val statsIntervalRight = ValueInterval(colStatRight.min, colStatRight.max, attrRight.dataType)
+      .asInstanceOf[NumericValueInterval]
+    val maxRight = statsIntervalRight.max
+    val minRight = statsIntervalRight.min
 
-    // determine the overlapping degree between predicate range and column's range
+    // determine the overlapping degree between predicate interval and column's interval
     val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0)
     val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
       // Left < Right or Left <= Right
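[Annotation, not part of the patch] A minimal sketch of what the renamed API does for the equality path above, with made-up column stats (the package path comes from the file being patched; the values are illustrative only):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
import org.apache.spark.sql.types.IntegerType

// Suppose column stats record min = 1 and max = 100 for an int column.
val statsInterval = ValueInterval(Some(1), Some(100), IntegerType)

// A literal inside [1, 100] may match rows; one outside cannot, which is
// how the equality estimate can short-circuit to zero selectivity.
assert(statsInterval.contains(Literal(42, IntegerType)))
assert(!statsInterval.contains(Literal(200, IntegerType)))
```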
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index f48196997a24d382c7c06ea2c279f7fe4d9bc289..dcbe36da91dfc264b2f1522133871d737ff3815f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -175,9 +175,9 @@ case class InnerOuterEstimation(join: Join) extends Logging {
       // Check if the two sides are disjoint
       val leftKeyStats = leftStats.attributeStats(leftKey)
       val rightKeyStats = rightStats.attributeStats(rightKey)
-      val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
-      val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
-      if (Range.isIntersected(lRange, rRange)) {
+      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
+      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+      if (ValueInterval.isIntersected(lInterval, rInterval)) {
         // Get the largest ndv among pairs of join keys
         val maxNdv = leftKeyStats.distinctCount.max(rightKeyStats.distinctCount)
         if (maxNdv > ndvDenom) ndvDenom = maxNdv
@@ -239,16 +239,16 @@ case class InnerOuterEstimation(join: Join) extends Logging {
     joinKeyPairs.foreach { case (leftKey, rightKey) =>
       val leftKeyStats = leftStats.attributeStats(leftKey)
       val rightKeyStats = rightStats.attributeStats(rightKey)
-      val lRange = Range(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
-      val rRange = Range(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
+      val lInterval = ValueInterval(leftKeyStats.min, leftKeyStats.max, leftKey.dataType)
+      val rInterval = ValueInterval(rightKeyStats.min, rightKeyStats.max, rightKey.dataType)
       // When we reach here, join selectivity is not zero, so each pair of join keys should be
       // intersected.
-      assert(Range.isIntersected(lRange, rRange))
+      assert(ValueInterval.isIntersected(lInterval, rInterval))
 
       // Update intersected column stats
       assert(leftKey.dataType.sameType(rightKey.dataType))
       val newNdv = leftKeyStats.distinctCount.min(rightKeyStats.distinctCount)
-      val (newMin, newMax) = Range.intersect(lRange, rRange, leftKey.dataType)
+      val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
       val newMaxLen = math.min(leftKeyStats.maxLen, rightKeyStats.maxLen)
       val newAvgLen = (leftKeyStats.avgLen + rightKeyStats.avgLen) / 2
       val newStats = ColumnStat(newNdv, newMin, newMax, 0, newAvgLen, newMaxLen)
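[Annotation, not part of the patch] A hedged sketch of the join-side checks above, again with made-up stats: disjoint key intervals let the estimator report an empty join, and overlapping ones are narrowed to their intersection.

```scala
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
import org.apache.spark.sql.types.IntegerType

// Keys in [1, 50] on the left and [60, 100] on the right can never be equal,
// so the join cardinality estimate is zero.
val lInterval = ValueInterval(Some(1), Some(50), IntegerType)
val rInterval = ValueInterval(Some(60), Some(100), IntegerType)
assert(!ValueInterval.isIntersected(lInterval, rInterval))

// With an overlapping right side [40, 100], the intervals intersect and the
// joined column's new min/max are taken from the overlap [40, 50].
val rOverlap = ValueInterval(Some(40), Some(100), IntegerType)
assert(ValueInterval.isIntersected(lInterval, rOverlap))
val (newMin, newMax) = ValueInterval.intersect(lInterval, rOverlap, IntegerType)
```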
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
similarity index 65%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
index 4ac5ba5689f82de2b775f6b93c51c93e2b954933..0caaf796a3b680a6c426493dad803df24da71e53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
@@ -22,12 +22,12 @@ import org.apache.spark.sql.types._
 
 
 /** Value range of a column. */
-trait Range {
+trait ValueInterval {
   def contains(l: Literal): Boolean
 }
 
-/** For simplicity we use decimal to unify operations of numeric ranges. */
-case class NumericRange(min: Decimal, max: Decimal) extends Range {
+/** For simplicity we use decimal to unify operations of numeric intervals. */
+case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval {
   override def contains(l: Literal): Boolean = {
     val lit = EstimationUtils.toDecimal(l.value, l.dataType)
     min <= lit && max >= lit
@@ -38,46 +38,49 @@ case class NumericRange(min: Decimal, max: Decimal) extends Range {
  * This version of Spark does not have min/max for binary/string types, we define their default
  * behaviors by this class.
  */
-class DefaultRange extends Range {
+class DefaultValueInterval extends ValueInterval {
   override def contains(l: Literal): Boolean = true
 }
 
 /** This is for columns with only null values. */
-class NullRange extends Range {
+class NullValueInterval extends ValueInterval {
   override def contains(l: Literal): Boolean = false
 }
 
-object Range {
-  def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match {
-    case StringType | BinaryType => new DefaultRange()
-    case _ if min.isEmpty || max.isEmpty => new NullRange()
+object ValueInterval {
+  def apply(
+      min: Option[Any],
+      max: Option[Any],
+      dataType: DataType): ValueInterval = dataType match {
+    case StringType | BinaryType => new DefaultValueInterval()
+    case _ if min.isEmpty || max.isEmpty => new NullValueInterval()
     case _ =>
-      NumericRange(
+      NumericValueInterval(
         min = EstimationUtils.toDecimal(min.get, dataType),
         max = EstimationUtils.toDecimal(max.get, dataType))
   }
 
-  def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match {
-    case (_, _: DefaultRange) | (_: DefaultRange, _) =>
-      // The DefaultRange represents string/binary types which do not have max/min stats,
+  def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match {
+    case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
+      // The DefaultValueInterval represents string/binary types which do not have max/min stats,
       // we assume they are intersected to be conservative on estimation
       true
-    case (_, _: NullRange) | (_: NullRange, _) =>
+    case (_, _: NullValueInterval) | (_: NullValueInterval, _) =>
       false
-    case (n1: NumericRange, n2: NumericRange) =>
+    case (n1: NumericValueInterval, n2: NumericValueInterval) =>
       n1.min.compareTo(n2.max) <= 0 && n1.max.compareTo(n2.min) >= 0
   }
 
   /**
-   * Intersected results of two ranges. This is only for two overlapped ranges.
+   * Intersected results of two intervals. This is only for two overlapped intervals.
    * The outputs are the intersected min/max values.
    */
-  def intersect(r1: Range, r2: Range, dt: DataType): (Option[Any], Option[Any]) = {
+  def intersect(r1: ValueInterval, r2: ValueInterval, dt: DataType): (Option[Any], Option[Any]) = {
     (r1, r2) match {
-      case (_, _: DefaultRange) | (_: DefaultRange, _) =>
+      case (_, _: DefaultValueInterval) | (_: DefaultValueInterval, _) =>
         // binary/string types don't support intersecting.
         (None, None)
-      case (n1: NumericRange, n2: NumericRange) =>
+      case (n1: NumericValueInterval, n2: NumericValueInterval) =>
         // Choose the maximum of two min values, and the minimum of two max values.
         val newMin = if (n1.min <= n2.min) n2.min else n1.min
         val newMax = if (n1.max <= n2.max) n1.max else n2.max
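[Annotation, not part of the patch] The two fallback intervals defined above behave as follows; a sketch under the same assumptions as the earlier examples:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval
import org.apache.spark.sql.types.{IntegerType, StringType}

// String/binary columns carry no min/max stats, so they get a
// DefaultValueInterval that conservatively contains every literal.
val stringInterval = ValueInterval(None, None, StringType)
assert(stringInterval.contains(Literal("anything")))

// A numeric column with absent min/max (only null values) gets a
// NullValueInterval, which contains no literal at all.
val nullInterval = ValueInterval(None, None, IntegerType)
assert(!nullInterval.contains(Literal(1, IntegerType)))
```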