From 678c4da0fa1bbfb6b5a0d3aced7aefa1bbbc193c Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@databricks.com>
Date: Mon, 4 May 2015 18:03:07 -0700
Subject: [PATCH] [SPARK-7266] Add ExpectsInputTypes to expressions when
 possible.

This should give us better analysis-time error messages (rather than runtime errors) and automatic type casting.
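
To illustrate the pattern (a minimal, self-contained sketch rather than Spark's actual classes): an expression mixes in ExpectsInputTypes and declares expectedChildTypes, and an analysis rule can then compare each child's dataType against the expected type and wrap mismatches in a Cast, so type errors surface during analysis instead of during evaluation.

    // Minimal sketch of the ExpectsInputTypes pattern; names mirror the patch
    // but the types here are simplified stand-ins, not Spark's real classes.
    sealed trait DataType
    case object BooleanType extends DataType
    case object StringType extends DataType

    trait Expression {
      def dataType: DataType
      def children: Seq[Expression]
    }

    trait ExpectsInputTypes { self: Expression =>
      // One expected type per child, checked at analysis time.
      def expectedChildTypes: Seq[DataType]
    }

    case class Cast(child: Expression, dataType: DataType) extends Expression {
      def children: Seq[Expression] = child :: Nil
    }

    object TypeCoercionSketch {
      // Analysis-time coercion: wrap any child whose type does not match
      // the declared expectation in a Cast to that expected type.
      def coerce(e: Expression with ExpectsInputTypes): Seq[Expression] =
        e.children.zip(e.expectedChildTypes).map {
          case (child, expected) if child.dataType != expected => Cast(child, expected)
          case (child, _) => child
        }
    }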

Author: Reynold Xin <rxin@databricks.com>

Closes #5796 from rxin/expected-input-types and squashes the following commits:

c900760 [Reynold Xin] [SPARK-7266] Add ExpectsInputTypes to expressions when possible.
---
 .../catalyst/analysis/HiveTypeCoercion.scala  | 58 ++++++++++---------
 .../sql/catalyst/expressions/Expression.scala |  3 +-
 .../sql/catalyst/expressions/arithmetic.scala |  4 +-
 .../sql/catalyst/expressions/predicates.scala | 22 ++++---
 .../expressions/stringOperations.scala        | 40 ++++++++-----
 5 files changed, 71 insertions(+), 56 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 73c9a1c7af..831fb4fe95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -239,37 +239,43 @@ trait HiveTypeCoercion {
         a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
 
       // we should cast all timestamp/date/string compare into string compare
-      case p: BinaryPredicate if p.left.dataType == StringType
-        && p.right.dataType == DateType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType == DateType =>
         p.makeCopy(Array(p.left, Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == DateType
-        && p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType == DateType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(Cast(p.left, StringType), p.right))
-      case p: BinaryPredicate if p.left.dataType == StringType
-        && p.right.dataType == TimestampType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType == TimestampType =>
         p.makeCopy(Array(p.left, Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == TimestampType
-        && p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType == TimestampType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(Cast(p.left, StringType), p.right))
-      case p: BinaryPredicate if p.left.dataType == TimestampType
-        && p.right.dataType == DateType =>
+      case p: BinaryComparison if p.left.dataType == TimestampType &&
+                                  p.right.dataType == DateType =>
         p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
-      case p: BinaryPredicate if p.left.dataType == DateType
-        && p.right.dataType == TimestampType =>
+      case p: BinaryComparison if p.left.dataType == DateType &&
+                                  p.right.dataType == TimestampType =>
         p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
 
-      case p: BinaryPredicate if p.left.dataType == StringType && p.right.dataType != StringType =>
+      case p: BinaryComparison if p.left.dataType == StringType &&
+                                  p.right.dataType != StringType =>
         p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
-      case p: BinaryPredicate if p.left.dataType != StringType && p.right.dataType == StringType =>
+      case p: BinaryComparison if p.left.dataType != StringType &&
+                                  p.right.dataType == StringType =>
         p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))
 
-      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == StringType) =>
+      case i @ In(a, b) if a.dataType == DateType &&
+                           b.forall(_.dataType == StringType) =>
         i.makeCopy(Array(Cast(a, StringType), b))
-      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == StringType) =>
+      case i @ In(a, b) if a.dataType == TimestampType &&
+                           b.forall(_.dataType == StringType) =>
         i.makeCopy(Array(Cast(a, StringType), b))
-      case i @ In(a, b) if a.dataType == DateType && b.forall(_.dataType == TimestampType) =>
+      case i @ In(a, b) if a.dataType == DateType &&
+                           b.forall(_.dataType == TimestampType) =>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
-      case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
+      case i @ In(a, b) if a.dataType == TimestampType &&
+                           b.forall(_.dataType == DateType) =>
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
 
       case Sum(e) if e.dataType == StringType =>
@@ -420,19 +426,19 @@ trait HiveTypeCoercion {
           )
 
         case LessThan(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                      e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
           LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
 
         case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                             e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
           LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
 
         case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                         e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
           GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
 
         case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-        e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+                                e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
           GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
 
         // Promote integers inside a binary expression with fixed-precision decimals to decimals,
@@ -481,8 +487,8 @@ trait HiveTypeCoercion {
       // No need to change the EqualNullSafe operators, too
       case e: EqualNullSafe => e
       // Otherwise turn them to Byte types so that there exists and ordering.
-      case p: BinaryComparison
-          if p.left.dataType == BooleanType && p.right.dataType == BooleanType =>
+      case p: BinaryComparison if p.left.dataType == BooleanType &&
+                                  p.right.dataType == BooleanType =>
         p.makeCopy(Array(Cast(p.left, ByteType), Cast(p.right, ByteType)))
     }
   }
@@ -564,10 +570,6 @@ trait HiveTypeCoercion {
       case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
       case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
 
-      // Compatible with Hive
-      case Substring(e, start, len) if e.dataType != StringType =>
-        Substring(Cast(e, StringType), start, len)
-
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 1d71c1b4b0..4fd1bc4dd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
@@ -86,6 +85,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
 
   override def foldable: Boolean = left.foldable && right.foldable
 
+  override def nullable: Boolean = left.nullable || right.nullable
+
   override def toString: String = s"($left $symbol $right)"
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 140ccd8d37..c7a37ad966 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -74,14 +74,12 @@ abstract class BinaryArithmetic extends BinaryExpression {
 
   type EvaluatedType = Any
 
-  def nullable: Boolean = left.nullable || right.nullable
-
   override lazy val resolved =
     left.resolved && right.resolved &&
     left.dataType == right.dataType &&
     !DecimalType.isFixed(left.dataType)
 
-  def dataType: DataType = {
+  override def dataType: DataType = {
     if (!resolved) {
       throw new UnresolvedException(this,
         s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9cb00cb273..26c38c56c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -70,16 +70,14 @@ trait PredicateHelper {
     expr.references.subsetOf(plan.outputSet)
 }
 
-abstract class BinaryPredicate extends BinaryExpression with Predicate {
-  self: Product =>
-  override def nullable: Boolean = left.nullable || right.nullable
-}
 
-case class Not(child: Expression) extends UnaryExpression with Predicate {
+case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes {
   override def foldable: Boolean = child.foldable
   override def nullable: Boolean = child.nullable
   override def toString: String = s"NOT $child"
 
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+
   override def eval(input: Row): Any = {
     child.eval(input) match {
       case null => null
@@ -120,7 +118,11 @@ case class InSet(value: Expression, hset: Set[Any])
   }
 }
 
-case class And(left: Expression, right: Expression) extends BinaryPredicate {
+case class And(left: Expression, right: Expression)
+  extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+
   override def symbol: String = "&&"
 
   override def eval(input: Row): Any = {
@@ -142,7 +144,11 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate {
   }
 }
 
-case class Or(left: Expression, right: Expression) extends BinaryPredicate {
+case class Or(left: Expression, right: Expression)
+  extends BinaryExpression with Predicate with ExpectsInputTypes {
+
+  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+
   override def symbol: String = "||"
 
   override def eval(input: Row): Any = {
@@ -164,7 +170,7 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
   }
 }
 
-abstract class BinaryComparison extends BinaryPredicate {
+abstract class BinaryComparison extends BinaryExpression with Predicate {
   self: Product =>
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index d597bf7ce7..d6f23df30f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -22,7 +22,7 @@ import java.util.regex.Pattern
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.types._
 
-trait StringRegexExpression {
+trait StringRegexExpression extends ExpectsInputTypes {
   self: BinaryExpression =>
 
   type EvaluatedType = Any
@@ -32,6 +32,7 @@ trait StringRegexExpression {
 
   override def nullable: Boolean = left.nullable || right.nullable
   override def dataType: DataType = BooleanType
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 
   // try cache the pattern for Literal
   private lazy val cache: Pattern = right match {
@@ -57,11 +58,11 @@ trait StringRegexExpression {
       if(r == null) {
         null
       } else {
-        val regex = pattern(r.asInstanceOf[UTF8String].toString)
+        val regex = pattern(r.asInstanceOf[UTF8String].toString())
         if(regex == null) {
           null
         } else {
-          matches(regex, l.asInstanceOf[UTF8String].toString)
+          matches(regex, l.asInstanceOf[UTF8String].toString())
         }
       }
     }
@@ -110,7 +111,7 @@ case class RLike(left: Expression, right: Expression)
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
 }
 
-trait CaseConversionExpression {
+trait CaseConversionExpression extends ExpectsInputTypes {
   self: UnaryExpression =>
 
   type EvaluatedType = Any
@@ -118,8 +119,9 @@ trait CaseConversionExpression {
   def convert(v: UTF8String): UTF8String
 
   override def foldable: Boolean = child.foldable
-  def nullable: Boolean = child.nullable
-  def dataType: DataType = StringType
+  override def nullable: Boolean = child.nullable
+  override def dataType: DataType = StringType
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
 
   override def eval(input: Row): Any = {
     val evaluated = child.eval(input)
@@ -136,7 +138,7 @@ trait CaseConversionExpression {
  */
 case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
   
-  override def convert(v: UTF8String): UTF8String = v.toUpperCase
+  override def convert(v: UTF8String): UTF8String = v.toUpperCase()
 
   override def toString: String = s"Upper($child)"
 }
@@ -146,21 +148,21 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
  */
 case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
   
-  override def convert(v: UTF8String): UTF8String = v.toLowerCase
+  override def convert(v: UTF8String): UTF8String = v.toLowerCase()
 
   override def toString: String = s"Lower($child)"
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
 trait StringComparison {
-  self: BinaryPredicate =>
+  self: BinaryExpression =>
+
+  def compare(l: UTF8String, r: UTF8String): Boolean
 
   override type EvaluatedType = Any
 
   override def nullable: Boolean = left.nullable || right.nullable
 
-  def compare(l: UTF8String, r: UTF8String): Boolean
-
   override def eval(input: Row): Any = {
     val leftEval = left.eval(input)
     if(leftEval == null) {
@@ -181,31 +183,35 @@ trait StringComparison {
  * A function that returns true if the string `left` contains the string `right`.
  */
 case class Contains(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` starts with the string `right`.
  */
 case class StartsWith(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` ends with the string `right`.
  */
 case class EndsWith(left: Expression, right: Expression)
-    extends BinaryPredicate with StringComparison {
+    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that takes a substring of its first argument starting at a given position.
  * Defined for String and Binary types.
  */
-case class Substring(str: Expression, pos: Expression, len: Expression) extends Expression {
+case class Substring(str: Expression, pos: Expression, len: Expression)
+  extends Expression with ExpectsInputTypes {
   
   type EvaluatedType = Any
 
@@ -219,6 +225,8 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
     if (str.dataType == BinaryType) str.dataType else StringType
   }
 
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
   @inline
@@ -258,7 +266,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) extends
           val (st, end) = slicePos(start, length, () => ba.length)
           ba.slice(st, end)
         case s: UTF8String =>
-          val (st, end) = slicePos(start, length, () => s.length)
+          val (st, end) = slicePos(start, length, () => s.length())
           s.slice(st, end)
       }
     }
-- 
GitLab