Skip to content
Snippets Groups Projects
Commit 678c4da0 authored by Reynold Xin's avatar Reynold Xin
Browse files

[SPARK-7266] Add ExpectsInputTypes to expressions when possible.

This should give us better analysis-time error messages (rather than runtime) and automatic type casting.

Author: Reynold Xin <rxin@databricks.com>

Closes #5796 from rxin/expected-input-types and squashes the following commits:

c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when possible.
parent 80554111
No related branches found
No related tags found
No related merge requests found
......@@ -239,37 +239,43 @@ trait HiveTypeCoercion {
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
// we should cast all timestamp/date/string compare into string compare
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == DateType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == StringType
&& p.right.dataType == TimestampType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryPredicate if p.left.dataType == TimestampType
&& p.right.dataType == DateType =>
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == DateType
&& p.right.dataType == TimestampType =>
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
case p: BinaryComparison if p.left.dataType != StringType &&
p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e) if e.dataType == StringType =>
......@@ -420,19 +426,19 @@ trait HiveTypeCoercion {
)
case LessThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
// Promote integers inside a binary expression with fixed-precision decimals to decimals,
......@@ -481,8 +487,8 @@ trait HiveTypeCoercion {
// No need to change the EqualNullSafe operators, too
case e: EqualNullSafe => e
// Otherwise turn them to Byte types so that there exists and ordering.
case p: BinaryComparison
if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
case p: BinaryComparison if p.left.dataType == BooleanType &&
p.right.dataType == BooleanType =>
p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
}
}
......@@ -564,10 +570,6 @@ trait HiveTypeCoercion {
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
// Compatible with Hive
case Substring(e, start, len) if e.dataType != StringType =>
Substring(Cast(e, StringType), start, len)
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
......
......@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
......@@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def foldable: Boolean = left.foldable && right.foldable
override def nullable: Boolean = left.nullable || right.nullable
override def toString: String = s"($left $symbol $right)"
}
......
......@@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {
type EvaluatedType = Any
def nullable: Boolean = left.nullable || right.nullable
override lazy val resolved =
left.resolved && right.resolved &&
left.dataType == right.dataType &&
!DecimalType.isFixed(left.dataType)
def dataType: DataType = {
override def dataType: DataType = {
if (!resolved) {
throw new UnresolvedException(this,
s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
......
......@@ -70,16 +70,14 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
abstract class BinaryPredicate extends BinaryExpression with Predicate {
self: Product =>
override def nullable: Boolean = left.nullable || right.nullable
}
case class Not(child: Expression) extends UnaryExpression with Predicate {
case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"NOT $child"
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
override def eval(input: Row): Any = {
child.eval(input) match {
case null => null
......@@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
}
}
case class And(left: Expression, right: Expression) extends BinaryPredicate {
case class And(left: Expression, right: Expression)
extends BinaryExpression with Predicate with ExpectsInputTypes {
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "&&"
override def eval(input: Row): Any = {
......@@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
}
}
case class Or(left: Expression, right: Expression) extends BinaryPredicate {
case class Or(left: Expression, right: Expression)
extends BinaryExpression with Predicate with ExpectsInputTypes {
override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
override def symbol: String = "||"
override def eval(input: Row): Any = {
......@@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
}
}
abstract class BinaryComparison extends BinaryPredicate {
abstract class BinaryComparison extends BinaryExpression with Predicate {
self: Product =>
}
......
......@@ -22,7 +22,7 @@ import java.util.regex.Pattern
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.types._
trait StringRegexExpression {
trait StringRegexExpression extends ExpectsInputTypes {
self: BinaryExpression =>
type EvaluatedType = Any
......@@ -32,6 +32,7 @@ trait StringRegexExpression {
override def nullable: Boolean = left.nullable || right.nullable
override def dataType: DataType = BooleanType
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
// try cache the pattern for Literal
private lazy val cache: Pattern = right match {
......@@ -57,11 +58,11 @@ trait StringRegexExpression {
if(r == null) {
null
} else {
val regex = pattern(r.asInstanceOf[UTF8String].toString)
val regex = pattern(r.asInstanceOf[UTF8String].toString())
if(regex == null) {
null
} else {
matches(regex, l.asInstanceOf[UTF8String].toString)
matches(regex, l.asInstanceOf[UTF8String].toString())
}
}
}
......@@ -110,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
}
trait CaseConversionExpression {
trait CaseConversionExpression extends ExpectsInputTypes {
self: UnaryExpression =>
type EvaluatedType = Any
......@@ -118,8 +119,9 @@ trait CaseConversionExpression {
def convert(v: UTF8String): UTF8String
override def foldable: Boolean = child.foldable
def nullable: Boolean = child.nullable
def dataType: DataType = StringType
override def nullable: Boolean = child.nullable
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)
override def eval(input: Row): Any = {
val evaluated = child.eval(input)
......@@ -136,7 +138,7 @@ trait CaseConversionExpression {
*/
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: UTF8String): UTF8String = v.toUpperCase
override def convert(v: UTF8String): UTF8String = v.toUpperCase()
override def toString: String = s"Upper($child)"
}
......@@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
*/
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: UTF8String): UTF8String = v.toLowerCase
override def convert(v: UTF8String): UTF8String = v.toLowerCase()
override def toString: String = s"Lower($child)"
}
/** A base trait for functions that compare two strings, returning a boolean. */
trait StringComparison {
self: BinaryPredicate =>
self: BinaryExpression =>
def compare(l: UTF8String, r: UTF8String): Boolean
override type EvaluatedType = Any
override def nullable: Boolean = left.nullable || right.nullable
def compare(l: UTF8String, r: UTF8String): Boolean
override def eval(input: Row): Any = {
val leftEval = left.eval(input)
if(leftEval == null) {
......@@ -181,31 +183,35 @@ trait StringComparison {
* A function that returns true if the string `left` contains the string `right`.
*/
case class Contains(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that returns true if the string `left` starts with the string `right`.
*/
case class StartsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that returns true if the string `left` ends with the string `right`.
*/
case class EndsWith(left: Expression, right: Expression)
extends BinaryPredicate with StringComparison {
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
}
/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
case class Substring(str: Expression, pos: Expression, len: Expression)
extends Expression with ExpectsInputTypes {
type EvaluatedType = Any
......@@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
if (str.dataType == BinaryType) str.dataType else StringType
}
override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
@inline
......@@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
val (st, end) = slicePos(start, length, () => ba.length)
ba.slice(st, end)
case s: UTF8String =>
val (st, end) = slicePos(start, length, () => s.length)
val (st, end) = slicePos(start, length, () => s.length())
s.slice(st, end)
}
}
......
0% Loading or try reloading the page.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment