diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f3fe58caa6fe2e4abe6a17a3a8d09e36b4cfba12..7bf10f199f1c77f4c9d9c45b2c0da99600797108 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.collection.immutable.TreeSet
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
                  |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
                """.stripMargin)
           } else {
-            TypeCheckResult.TypeCheckSuccess
+            TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
           }
         }
       case _ =>
-        if (list.exists(l => l.dataType != value.dataType)) {
-          TypeCheckResult.TypeCheckFailure("Arguments must be same type")
+        val mismatchOpt = list.find(l => l.dataType != value.dataType)
+        if (mismatchOpt.isDefined) {
+          TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
+            s"${value.dataType} != ${mismatchOpt.get.dataType}")
         } else {
-          TypeCheckResult.TypeCheckSuccess
+          TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
         }
     }
   }
 
   override def children: Seq[Expression] = value +: list
   lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
+  private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
@@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
       var hasNull = false
       list.foreach { e =>
         val v = e.eval(input)
-        if (v == evaluatedValue) {
-          return true
-        } else if (v == null) {
+        if (v == null) {
           hasNull = true
+        } else if (ordering.equiv(v, evaluatedValue)) {
+          return true
         }
       }
       if (hasNull) {
@@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   override def nullable: Boolean = child.nullable || hasNull
 
   protected override def nullSafeEval(value: Any): Any = {
-    if (hset.contains(value)) {
+    if (set.contains(value)) {
       true
     } else if (hasNull) {
       null
@@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
     }
   }
 
-  def getHSet(): Set[Any] = hset
+  @transient private[this] lazy val set = child.dataType match {
+    case _: AtomicType => hset
+    case _: NullType => hset
+    case _ =>
+      // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
+      TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
+  }
+
+  def getSet(): Set[Any] = set
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val setName = classOf[Set[Any]].getName
     val InSetName = classOf[InSet].getName
     val childGen = child.genCode(ctx)
     ctx.references += this
-    val hsetTerm = ctx.freshName("hset")
-    val hasNullTerm = ctx.freshName("hasNull")
-    ctx.addMutableState(setName, hsetTerm,
-      s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
-    ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
+    val setTerm = ctx.freshName("set")
+    val setNull = if (hasNull) {
+      s"""
+         |if (!${ev.value}) {
+         |  ${ev.isNull} = true;
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+    ctx.addMutableState(setName, setTerm,
+      s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
     ev.copy(code = s"""
       ${childGen.code}
       boolean ${ev.isNull} = ${childGen.isNull};
       boolean ${ev.value} = false;
       if (!${ev.isNull}) {
-        ${ev.value} = $hsetTerm.contains(${childGen.value});
-        if (!${ev.value} && $hasNullTerm) {
-          ${ev.isNull} = true;
-        }
+        ${ev.value} = $setTerm.contains(${childGen.value});
+        $setNull
       }
      """)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 6fe295c3dd936b39fdc0c4bb1ee37769e0a76d9a..ef510a95ef446e3b35f250f803ab587d7c60561d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     test(s"3VL $name") {
       truthTable.foreach {
         case (l, r, answer) =>
-          val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
+          val expr = op(NonFoldableLiteral.create(l, BooleanType),
+            NonFoldableLiteral.create(r, BooleanType))
           checkEvaluation(expr, answer)
       }
     }
@@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
         (false, true) ::
         (null, null) :: Nil
     notTrueTable.foreach { case (v, answer) =>
-      checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
+      checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
     }
     checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
   }
@@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
       (null, null, null) :: Nil)
 
   test("IN") {
-    checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
-    checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
-      Seq(NonFoldableLiteral(null, IntegerType))), null)
-    checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
+      Literal(2))), null)
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
+      Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+    checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
     checkEvaluation(In(Literal(1), Seq.empty), false)
-    checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
-    checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
-    checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
+    checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
+    checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+      true)
+    checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
+      null)
     checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
     checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
     checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
     checkEvaluation(
-      And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
+      And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
+        Literal(2)))),
       true)
 
-    val ns = NonFoldableLiteral(null, StringType)
+    val ns = NonFoldableLiteral.create(null, StringType)
     checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
     checkEvaluation(In(ns, Seq(ns)), null)
     checkEvaluation(In(Literal("a"), Seq(ns)), null)
@@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
           case _ => value
         }
       }
-      val input = inputData.map(NonFoldableLiteral(_, t))
+      val input = inputData.map(NonFoldableLiteral.create(_, t))
       val expected = if (inputData(0) == null) {
         null
       } else if (inputData.slice(1, 10).contains(inputData(0))) {
@@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("BinaryComparison: null test") {
     // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
     val normalInt = Literal(-1)
-    val nullInt = NonFoldableLiteral(null, IntegerType)
+    val nullInt = NonFoldableLiteral.create(null, IntegerType)
 
     def nullTest(op: (Expression, Expression) => Expression): Unit = {
       checkEvaluation(op(normalInt, nullInt), null)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 6a77580b29a2181203e7da8076fadca88045c29d..28bf7b6f84341c2dcfcede51831ad2e421649cb6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest {
       val optimizedPlan = OptimizeIn(plan)
       optimizedPlan match {
         case Filter(cond, _)
-          if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
+          if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
         // pass
         case _ => fail("Unexpected result for OptimizedIn")
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 68f61cfab6d2faf605ccf020b15b03536277eea5..5171aaebc990780f22c989ac7564064d9884d510 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
     assert(e.message.contains("Invalid number of arguments"))
   }
+
+  test("SPARK-21228: InSet incorrect handling of structs") {
+    withTempView("A") {
+      // reduce this from the default of 10 so the repro query text is not too long
+      withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
+        // a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
+        spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
+          .createOrReplaceTempView("A")
+        val df = sql(
+          """
+            |SELECT * from
+            | (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
+            | -- the IN will become InSet with a Set of GenericInternalRows
+            | -- a GenericInternalRow is never equal to an UnsafeRow so the query would
+            | -- returns 0 results, which is incorrect
+            | WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
+            |   NAMED_STRUCT('a', 3L, 'b', 3L))
+          """.stripMargin)
+        checkAnswer(df, Row(Row(1, 1)))
+      }
+    }
+  }
 }