diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index d6f23df30ffb418bc0ebef8642c1ab3423280579..7683e0990ce80fb6aadb4198f5e405de8d982085 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
-trait StringComparison {
+trait StringComparison extends ExpectsInputTypes {
   self: BinaryExpression =>
 
   def compare(l: UTF8String, r: UTF8String): Boolean
@@ -163,6 +163,8 @@ trait StringComparison {
 
   override def nullable: Boolean = left.nullable || right.nullable
 
+  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+
   override def eval(input: Row): Any = {
     val leftEval = left.eval(input)
     if(leftEval == null) {
@@ -183,27 +185,24 @@ trait StringComparison {
  * A function that returns true if the string `left` contains the string `right`.
  */
 case class Contains(left: Expression, right: Expression)
-    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
+    extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` starts with the string `right`.
  */
 case class StartsWith(left: Expression, right: Expression)
-    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
+    extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**
  * A function that returns true if the string `left` ends with the string `right`.
  */
 case class EndsWith(left: Expression, right: Expression)
-    extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes {
+    extends BinaryExpression with Predicate with StringComparison {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
 }
 
 /**