Skip to content
Snippets Groups Projects
Commit 3059291e authored by wangfei's avatar wangfei Committed by Reynold Xin
Browse files

[SQL][Minor] make StringComparison extends ExpectsInputTypes

make StringComparison extends ExpectsInputTypes and added expectedChildTypes, so do not need override expectedChildTypes in each subclass

Author: wangfei <wangfei1@huawei.com>

Closes #5905 from scwf/ExpectsInputTypes and squashes the following commits:

b374ddf [wangfei] make stringcomparison extends ExpectsInputTypes
parent fec7b29f
No related branches found
No related tags found
No related merge requests found
...@@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE ...@@ -154,7 +154,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE
} }
/** A base trait for functions that compare two strings, returning a boolean. */ /** A base trait for functions that compare two strings, returning a boolean. */
trait StringComparison { trait StringComparison extends ExpectsInputTypes {
self: BinaryExpression => self: BinaryExpression =>
def compare(l: UTF8String, r: UTF8String): Boolean def compare(l: UTF8String, r: UTF8String): Boolean
...@@ -163,6 +163,8 @@ trait StringComparison { ...@@ -163,6 +163,8 @@ trait StringComparison {
override def nullable: Boolean = left.nullable || right.nullable override def nullable: Boolean = left.nullable || right.nullable
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
override def eval(input: Row): Any = { override def eval(input: Row): Any = {
val leftEval = left.eval(input) val leftEval = left.eval(input)
if(leftEval == null) { if(leftEval == null) {
...@@ -183,27 +185,24 @@ trait StringComparison { ...@@ -183,27 +185,24 @@ trait StringComparison {
* A function that returns true if the string `left` contains the string `right`. * A function that returns true if the string `left` contains the string `right`.
*/ */
case class Contains(left: Expression, right: Expression) case class Contains(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
} }
/** /**
* A function that returns true if the string `left` starts with the string `right`. * A function that returns true if the string `left` starts with the string `right`.
*/ */
case class StartsWith(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
} }
/** /**
* A function that returns true if the string `left` ends with the string `right`. * A function that returns true if the string `left` ends with the string `right`.
*/ */
case class EndsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression)
extends BinaryExpression with Predicate with StringComparison with ExpectsInputTypes { extends BinaryExpression with Predicate with StringComparison {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
} }
/** /**
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment