Skip to content
Snippets Groups Projects
Commit 3d2134fc authored by Yijie Shen, committed by Reynold Xin
Browse files

[SPARK-9055][SQL] WidenTypes should also support Intersect and Except

JIRA: https://issues.apache.org/jira/browse/SPARK-9055

cc rxin

Author: Yijie Shen <henry.yijieshen@gmail.com>

Closes #7491 from yijieshen/widen and squashes the following commits:

079fa52 [Yijie Shen] widenType support for intersect and expect
parent cdc36eef
No related branches found
No related tags found
No related merge requests found
...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis ...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable import javax.annotation.Nullable
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
...@@ -168,52 +168,65 @@ object HiveTypeCoercion { ...@@ -168,52 +168,65 @@ object HiveTypeCoercion {
* - LongType to DoubleType * - LongType to DoubleType
*/ */
// NOTE(review): this region is a two-column diff rendering. Unchanged lines repeat the same
// text twice; changed lines carry the pre-change text followed by the post-change text.
object WidenTypes extends Rule[LogicalPlan] { object WidenTypes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// TODO: unions with fixed-precision decimals
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)
// New column: widenOutputTypes pairs the left and right output attributes positionally.
// If exactly one side is a string, the other side is cast to StringType; otherwise, when the
// types differ, both sides are cast to the type returned by findTightestCommonTypeOfTwo, or
// left unchanged if it returns nothing. A child is wrapped in a Project only when at least
// one of its column types changed. `planName` is used only in the debug log messages.
// Note that zip pairs columns only up to the shorter of the two outputs.
case (lhs, rhs) if lhs.dataType != rhs.dataType => private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan):
logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}") (LogicalPlan, LogicalPlan) = {
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
val newLeft = // TODO: with fixed-precision decimals
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() val castedInput = left.output.zip(right.output).map {
val newRight = // When a string is found on one side, make the other side a string too.
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
(newLeft, newRight) case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
}.getOrElse { (Alias(Cast(lhs, StringType), lhs.name)(), rhs)
// If there is no applicable conversion, leave expression unchanged.
(lhs, rhs) case (lhs, rhs) if lhs.dataType != rhs.dataType =>
} logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}")
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
val newLeft =
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
val newRight =
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()
(newLeft, newRight)
}.getOrElse {
// If there is no applicable conversion, leave expression unchanged.
(lhs, rhs)
}
case other => other case other => other
} }
val (castedLeft, castedRight) = castedInput.unzip val (castedLeft, castedRight) = castedInput.unzip
val newLeft = val newLeft =
if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
logDebug(s"Widening numeric types in union $castedLeft ${left.output}") logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}")
Project(castedLeft, left) Project(castedLeft, left)
} else { } else {
left left
} }
val newRight = val newRight =
if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
logDebug(s"Widening numeric types in union $castedRight ${right.output}") logDebug(s"Widening numeric types in $planName $castedRight ${right.output}")
Project(castedRight, right) Project(castedRight, right)
} else { } else {
right right
} }
(newLeft, newRight)
}
// New column: the rule now handles Union, Except and Intersect the same way. Each case fires
// only when the node's children are resolved but the node itself is not, and rebuilds the same
// node type over the two children returned by widenOutputTypes.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right)
Union(newLeft, newRight) Union(newLeft, newRight)
case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right)
Except(newLeft, newRight)
case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right)
Intersect(newLeft, newRight)
} }
} }
......
...@@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { ...@@ -141,6 +141,10 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output override def output: Seq[Attribute] = left.output
// Added by this change: the node counts as resolved only when both children are resolved and
// every positionally paired left/right output attribute has the same data type.
// NOTE(review): zip stops at the shorter output, so a column-count mismatch between the two
// sides is not detected by this check.
override lazy val resolved: Boolean =
childrenResolved &&
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
} }
case class InsertIntoTable( case class InsertIntoTable(
...@@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode { ...@@ -437,4 +441,8 @@ case object OneRowRelation extends LeafNode {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output override def output: Seq[Attribute] = left.output
// Added by this change: the node counts as resolved only when both children are resolved and
// every positionally paired left/right output attribute has the same data type.
// NOTE(review): zip stops at the shorter output, so a column-count mismatch between the two
// sides is not detected by this check.
override lazy val resolved: Boolean =
childrenResolved &&
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
} }
...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis ...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LocalRelation, Project} import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
...@@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest { ...@@ -305,6 +305,38 @@ class HiveTypeCoercionSuite extends PlanTest {
) )
} }
test("WidenTypes for union except and intersect") {
  // Asserts that `logical` produces exactly the expected column types, in order.
  // The column count is checked first: zip silently truncates to the shorter sequence, so
  // without it a plan that lost columns would still pass the per-column comparison.
  def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
    assert(logical.output.length === expectTypes.length)
    logical.output.zip(expectTypes).foreach { case (attr, dt) =>
      assert(attr.dataType === dt)
    }
  }

  // Two relations whose columns differ in type at every position.
  val left = LocalRelation(
    AttributeReference("i", IntegerType)(),
    AttributeReference("u", DecimalType.Unlimited)(),
    AttributeReference("b", ByteType)(),
    AttributeReference("d", DoubleType)())
  val right = LocalRelation(
    AttributeReference("s", StringType)(),
    AttributeReference("d", DecimalType(2, 1))(),
    AttributeReference("f", FloatType)(),
    AttributeReference("l", LongType)())

  val wt = HiveTypeCoercion.WidenTypes
  // Per-column type both sides are expected to share after the rule has run.
  val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType)

  // The rule must keep the node type; a failed cast fails the test.
  val r1 = wt(Union(left, right)).asInstanceOf[Union]
  val r2 = wt(Except(left, right)).asInstanceOf[Except]
  val r3 = wt(Intersect(left, right)).asInstanceOf[Intersect]

  checkOutput(r1.left, expectedTypes)
  checkOutput(r1.right, expectedTypes)
  checkOutput(r2.left, expectedTypes)
  checkOutput(r2.right, expectedTypes)
  checkOutput(r3.left, expectedTypes)
  checkOutput(r3.right, expectedTypes)
}
/** /**
* There are rules that need to not fire before child expressions get resolved. * There are rules that need to not fire before child expressions get resolved.
* We use this test to make sure those rules do not fire early. * We use this test to make sure those rules do not fire early.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment