From e1e05873fc75781b6dd3f7fadbfb57824f83054e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@appier.com>
Date: Wed, 5 Aug 2015 11:38:56 -0700
Subject: [PATCH] [SPARK-9403] [SQL] Add codegen support in In and InSet

This continues tarekauel's work in #7778.

Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7893 from viirya/codegen_in and squashes the following commits:

81ff97b [Liang-Chi Hsieh] For comments.
47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
cf4bf41 [Liang-Chi Hsieh] For comments.
f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
446bbcd [Liang-Chi Hsieh] Fix bug.
b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test.
224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types.
86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold.
b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet
---
 .../sql/catalyst/expressions/predicates.scala | 63 +++++++++++++++++--
 .../sql/catalyst/optimizer/Optimizer.scala    |  2 +-
 .../catalyst/expressions/PredicateSuite.scala | 37 ++++++++++-
 .../catalyst/optimizer/OptimizeInSuite.scala  | 14 ++++-
 .../datasources/DataSourceStrategy.scala      |  7 +++
 .../spark/sql/ColumnExpressionSuite.scala     |  6 ++
 6 files changed, 119 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b69bbabee7..68c832d719 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -97,32 +100,80 @@ case class Not(child: Expression)
 /**
  * Evaluates to `true` if `list` contains `value`.
  */
-case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate
+    with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (list.exists(l => l.dataType != value.dataType)) {
+      TypeCheckResult.TypeCheckFailure(
+        "Arguments must be same type")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   override def children: Seq[Expression] = value +: list
 
-  override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
   override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
 
   override def eval(input: InternalRow): Any = {
     val evaluatedValue = value.eval(input)
     list.exists(e => e.eval(input) == evaluatedValue)
   }
-}
 
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val valueGen = value.gen(ctx)
+    val listGen = list.map(_.gen(ctx))
+    val listCode = listGen.map(x =>
+      s"""
+        if (!${ev.primitive}) {
+          ${x.code}
+          if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
+            ${ev.primitive} = true;
+          }
+        }
+       """).mkString("\n")
+    s"""
+      ${valueGen.code}
+      boolean ${ev.primitive} = false;
+      boolean ${ev.isNull} = false;
+      $listCode
+    """
+  }
+}
 
 /**
  * Optimized version of In clause, when all filter values of In clause are
  * static.
  */
-case class InSet(child: Expression, hset: Set[Any])
-  extends UnaryExpression with Predicate with CodegenFallback {
+case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
 
-  override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
   override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
 
   override def eval(input: InternalRow): Any = {
     hset.contains(child.eval(input))
   }
+
+  def getHSet(): Set[Any] = hset
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val setName = classOf[Set[Any]].getName
+    val InSetName = classOf[InSet].getName
+    val childGen = child.gen(ctx)
+    ctx.references += this
+    val hsetTerm = ctx.freshName("hset")
+    ctx.addMutableState(setName, hsetTerm,
+      s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
+    s"""
+      ${childGen.code}
+      boolean ${ev.isNull} = false;
+      boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
+     """
+  }
 }
 
 case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706dcb3..4ab5ac2c61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
 object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
-      case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+      case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 =>
         val hSet = list.map(e => e.eval(EmptyRow))
         InSet(v, HashSet() ++ hSet)
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index d7eb13c50b..7beef71845 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.types._
 
 
 class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
+
+    val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+      LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+    primitiveTypes.map { t =>
+      val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+      val inputData = Seq.fill(10) {
+        val value = dataGen.apply()
+        value match {
+          case d: Double if d.isNaN => 0.0d
+          case f: Float if f.isNaN => 0.0f
+          case _ => value
+        }
+      }
+      val input = inputData.map(Literal(_))
+      checkEvaluation(In(input(0), input.slice(1, 10)),
+        inputData.slice(1, 10).contains(inputData(0)))
+    }
   }
 
   test("INSET") {
@@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(InSet(three, hS), false)
     checkEvaluation(InSet(three, nS), false)
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+
+    val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+      LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+    primitiveTypes.map { t =>
+      val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+      val inputData = Seq.fill(10) {
+        val value = dataGen.apply()
+        value match {
+          case d: Double if d.isNaN => 0.0d
+          case f: Float if f.isNaN => 0.0f
+          case _ => value
+        }
+      }
+      val input = inputData.map(Literal(_))
+      checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
+        inputData.slice(1, 10).contains(inputData(0)))
+    }
   }
 
   private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 1d433275fe..6f7b5b9572 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
 
-  test("OptimizedIn test: In clause optimized to InSet") {
+  test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
     val originalQuery =
       testRelation
         .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
         .analyze
 
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery)
+  }
+
+  test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
+        .analyze
+
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
+        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
         .analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a43bccbe69..e5dc676b87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       case expressions.InSet(a: Attribute, set) =>
         Some(sources.In(a.name, set.toArray))
 
+      // Because we only convert In to InSet in Optimizer when there are more than certain
+      // items. So it is possible we still get an In expression here that needs to be pushed
+      // down.
+      case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
+        val hSet = list.map(e => e.eval(EmptyRow))
+        Some(sources.In(a.name, hSet.toArray))
+
       case expressions.IsNull(a: Attribute) =>
         Some(sources.IsNull(a.name))
       case expressions.IsNotNull(a: Attribute) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 35ca0b4c7c..b351380373 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
       df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
     checkAnswer(df.filter($"b".in("z", "y")),
       df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
+
+    val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
+
+    intercept[AnalysisException] {
+      df2.filter($"a".in($"b"))
+    }
   }
 
   val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
-- 
GitLab