diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index b69bbabee7e810d212097b4f7c356ee8bb597ba1..68c832d7194d4cb1d142725638ff3ef37a1c34af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -97,32 +100,86 @@ case class Not(child: Expression)
 /**
  * Evaluates to `true` if `list` contains `value`.
  */
-case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
+case class In(value: Expression, list: Seq[Expression]) extends Predicate
+    with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
+
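+  // Analysis-time check: every list expression must have exactly the value's
+  // data type, otherwise the query is rejected.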
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (list.exists(l => l.dataType != value.dataType)) {
+      TypeCheckResult.TypeCheckFailure("Arguments must be the same type")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   override def children: Seq[Expression] = value +: list
 
-  override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
   override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
 
   override def eval(input: InternalRow): Any = {
     val evaluatedValue = value.eval(input)
     list.exists(e => e.eval(input) == evaluatedValue)
   }
-}
 
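+  // Generates one equality check per list element; each check is guarded by
+  // `if (!ev.primitive)`, so evaluation short-circuits after the first match.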
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val valueGen = value.gen(ctx)
+    val listGen = list.map(_.gen(ctx))
+    val listCode = listGen.map(x =>
+      s"""
+        if (!${ev.primitive}) {
+          ${x.code}
+          if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
+            ${ev.primitive} = true;
+          }
+        }
+       """).mkString("\n")
+    s"""
+      ${valueGen.code}
+      boolean ${ev.primitive} = false;
+      boolean ${ev.isNull} = false;
+      $listCode
+    """
+  }
+}
 
 /**
  * Optimized version of In clause, when all filter values of In clause are
  * static.
  */
-case class InSet(child: Expression, hset: Set[Any])
-  extends UnaryExpression with Predicate with CodegenFallback {
+case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
 
-  override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
+  override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
   override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
 
   override def eval(input: InternalRow): Any = {
     hset.contains(child.eval(input))
   }
+
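+  // Exposed so that generated code can fetch the set through the references array.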
+  def getHSet(): Set[Any] = hset
+
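+  // An arbitrary Set cannot be written as a literal in generated Java source, so
+  // it is passed through ctx.references and cached once in a mutable state field.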
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val setName = classOf[Set[Any]].getName
+    val inSetName = classOf[InSet].getName
+    val childGen = child.gen(ctx)
+    ctx.references += this
+    val hsetTerm = ctx.freshName("hset")
+    ctx.addMutableState(setName, hsetTerm,
+      s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
+    s"""
+      ${childGen.code}
+      boolean ${ev.isNull} = false;
+      boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
+     """
+  }
 }
 
 case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 29d706dcb39a7db75cb8fec873effcbf42d12f68..4ab5ac2c61e3ce3d316326dcc93d5124c6d80a63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -393,7 +393,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
 object OptimizeIn extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsDown {
-      case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
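+      // Only convert when the list is large enough for hash lookups to pay off;
+      // smaller IN lists stay as In so they can still be pushed down directly.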
+      case In(v, list) if list.forall(_.isInstanceOf[Literal]) && list.size > 10 =>
         val hSet = list.map(e => e.eval(EmptyRow))
         InSet(v, HashSet() ++ hSet)
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index d7eb13c50b134691b226d68e8f7bd8ed5cd8bcf0..7beef71845e434cc883675f95384d6c7ddd4a5bb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.types._
 
 
 class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -118,6 +119,25 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
     checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
+
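+    // Replace NaN samples: boxed equality (used by eval) treats NaN as equal to
+    // itself, while the generated primitive comparison does not.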
+    val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+      LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+    primitiveTypes.foreach { t =>
+      val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+      val inputData = Seq.fill(10) {
+        val value = dataGen.apply()
+        value match {
+          case d: Double if d.isNaN => 0.0d
+          case f: Float if f.isNaN => 0.0f
+          case _ => value
+        }
+      }
+      val input = inputData.map(Literal(_))
+      checkEvaluation(In(input(0), input.slice(1, 10)),
+        inputData.slice(1, 10).contains(inputData(0)))
+    }
   }
 
   test("INSET") {
@@ -134,6 +154,25 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(InSet(three, hS), false)
     checkEvaluation(InSet(three, nS), false)
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+
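+    // As in the IN test above, NaN samples are replaced to keep equality
+    // semantics unambiguous across evaluation paths.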
+    val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
+      LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
+    primitiveTypes.foreach { t =>
+      val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+      val inputData = Seq.fill(10) {
+        val value = dataGen.apply()
+        value match {
+          case d: Double if d.isNaN => 0.0d
+          case f: Float if f.isNaN => 0.0f
+          case _ => value
+        }
+      }
+      val input = inputData.map(Literal(_))
+      checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
+        inputData.slice(1, 10).contains(inputData(0)))
+    }
   }
 
   private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 1d433275fed2eaa9cdb35d94daafcf36cf28c269..6f7b5b9572e22d335caf2b88a515a532def42e04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
 
-  test("OptimizedIn test: In clause optimized to InSet") {
+  test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
     val originalQuery =
       testRelation
         .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
         .analyze
 
+    val optimized = Optimize.execute(originalQuery.analyze)
+    comparePlans(optimized, originalQuery)
+  }
+
+  test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
+    val originalQuery =
+      testRelation
+        .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
+        .analyze
+
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
+        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
         .analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a43bccbe6927c55020eacd6966e0af96ff72a276..e5dc676b878416c286f6f891b3adaa4e577e6929 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
       case expressions.InSet(a: Attribute, set) =>
         Some(sources.In(a.name, set.toArray))
 
+      // The Optimizer only converts In to InSet when the list contains more than a
+      // certain number of items, so an In expression over literals can still reach
+      // this point and needs to be pushed down as well.
+      case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) =>
+        val hSet = list.map(e => e.eval(EmptyRow))
+        Some(sources.In(a.name, hSet.toArray))
+
       case expressions.IsNull(a: Attribute) =>
         Some(sources.IsNull(a.name))
       case expressions.IsNotNull(a: Attribute) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 35ca0b4c7cc214df42f8e80ef222e008cc1636f5..b35138037325991a960c87fe604f4b3e341e093f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -357,6 +357,14 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
       df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
     checkAnswer(df.filter($"b".in("z", "y")),
       df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
+
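+    // Column `a` holds ints while `b` holds arrays; the type check added to In
+    // should reject the mismatched comparison during analysis.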
+    val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
+
+    intercept[AnalysisException] {
+      df2.filter($"a".in($"b"))
+    }
   }
 
   val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(