diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 438742565c51d58aeee657d7915f6bece8f8cc74..bf1bc5dffba78e2b1fd3ca805979d61d720574d3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@ import com.google.common.primitives.UnsignedBytes;
 
 import org.apache.spark.annotation.Private;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
 
 @Private
 public class PrefixComparators {
@@ -82,7 +83,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       float a = Float.intBitsToFloat((int) aPrefix);
       float b = Float.intBitsToFloat((int) bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareFloats(a, b);
     }
 
     public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       double a = Double.longBitsToDouble(aPrefix);
       double b = Double.longBitsToDouble(bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareDoubles(a, b);
     }
 
     public long computePrefix(double value) {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e6374f17d858f2b152d087c0dfc010b4289094db..c5816949cd360ce8907f0b25c6f94dfd45fc6f73 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
     hashAbs
   }
 
+  /**
+   * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN double.
+   */
+  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
+    val xIsNan: Boolean = java.lang.Double.isNaN(x)
+    val yIsNan: Boolean = java.lang.Double.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
+  /**
+   * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN float.
+   */
+  def nanSafeCompareFloats(x: Float, y: Float): Int = {
+    val xIsNan: Boolean = java.lang.Float.isNaN(x)
+    val yIsNan: Boolean = java.lang.Float.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
   /** Returns the system properties map that is thread-safe to iterator over. It gets the
     * properties which have been set explicitly, as well as those for which only a default value
     * has been defined. */
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c7638507c88c60677900a1b1ac429cbca328fd45..8f7e402d5f2a64a1f4806ba2f2b69e7ed598de01 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
 import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     // scalastyle:on println
     assert(buffer.toString === "t circular test circular\n")
   }
+
+  test("nanSafeCompareDoubles") {
+    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+      assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+      assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0d, 0d)
+    shouldMatchDefaultOrder(0d, 1d)
+    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+    assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+  }
+
+  test("nanSafeCompareFloats") {
+    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+      assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+      assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0f, 0f)
+    shouldMatchDefaultOrder(1f, 1f)
+    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+    assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505dfa7d758b6e4b2872c13b3bc9a50b2d6d86..dc03e374b51dbd372192fb11fa1e17f85e447dc3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
+
+  test("float prefix comparator handles NaN properly") {
+    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+  }
+
+  test("double prefix comparator handles NaNs properly") {
+    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
+    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
+  }
+
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87294a0e21441f1f35ab2708086e4cb5d24daf9d..8cd9e7bc60a03516d9fb87138fa4a784e5a8316a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -215,6 +215,9 @@ public final class UnsafeRow extends MutableRow {
   public void setDouble(int ordinal, double value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
     PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
   }
 
@@ -243,6 +246,9 @@ public final class UnsafeRow extends MutableRow {
   public void setFloat(int ordinal, float value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 2cb64d00935de8cc15fa439809a1d8ad48910739..91449479fa539dfdf787c454ada5435b83f8e942 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -403,20 +403,28 @@ trait Row extends Serializable {
       if (!isNullAt(i)) {
         val o1 = get(i)
         val o2 = other.get(i)
-        if (o1.isInstanceOf[Array[Byte]]) {
-          // handle equality of Array[Byte]
-          val b1 = o1.asInstanceOf[Array[Byte]]
-          if (!o2.isInstanceOf[Array[Byte]] ||
-            !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+        o1 match {
+          case b1: Array[Byte] =>
+            if (!o2.isInstanceOf[Array[Byte]] ||
+                !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+              return false
+            }
+          case f1: Float if java.lang.Float.isNaN(f1) =>
+            if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+              return false
+            }
+          case d1: Double if java.lang.Double.isNaN(d1) =>
+            if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+              return false
+            }
+          case _ => if (o1 != o2) {
             return false
           }
-        } else if (o1 != o2) {
-          return false
         }
       }
       i += 1
     }
-    return true
+    true
   }
 
   /* ---------------------- utility methods for Scala ---------------------- */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 10f411ff7451affb1cadc142efbd6c8f97c2a145..606f770cb4f7bac96c31cdd1950904e9a6d242fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -194,6 +194,8 @@ class CodeGenContext {
    */
   def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
     case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+    case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
+    case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
     case other => s"$c1.equals($c2)"
   }
@@ -204,6 +206,8 @@ class CodeGenContext {
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
     case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
+    case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
+    case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
     // use c1 - c2 may overflow
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 40ec3df224ce14750db7f96d3c48713e50273f12..a53ec31ee6a4bc14465621177e020370f76c8b60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    if (ctx.isPrimitiveType(left.dataType)) {
+    if (ctx.isPrimitiveType(left.dataType)
+        && left.dataType != FloatType
+        && left.dataType != DoubleType) {
       // faster version
       defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
     } else {
@@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
   override def symbol: String = "="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    if (left.dataType != BinaryType) input1 == input2
-    else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    if (left.dataType == FloatType) {
+      Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+    } else if (left.dataType == DoubleType) {
+      Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+    } else if (left.dataType != BinaryType) {
+      input1 == input2
+    } else {
+      java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     } else if (input1 == null || input2 == null) {
       false
     } else {
-      if (left.dataType != BinaryType) {
+      if (left.dataType == FloatType) {
+        Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+      } else if (left.dataType == DoubleType) {
+        Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+      } else if (left.dataType != BinaryType) {
         input1 == input2
       } else {
         java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index 986c2ab055386cf7878adc4260540cd67fe3e16c..2a1bf0938e5a832b15daef11ca24e758d2fbfd37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Double] {
+    override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
+  }
   private[sql] val asIntegral = DoubleAsIfIntegral
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 9bd48ece83a1c32b3fda14d32f6f8805fbbc267c..08e22252aef82566284faa32a7f21b588aa942c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Float] {
+    override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
+  }
   private[sql] val asIntegral = FloatAsIfIntegral
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e05218a23aa736dc14444ef021d7fdd0d0695ca6..f4fbc49677ca3c835e4318faf8e32bd7d32bb68d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.math._
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
 
 /**
  * Additional tests for code generation.
@@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     futures.foreach(Await.result(_, 10.seconds))
   }
 
+  // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+  // contain two columns of that type, then for pairs of randomly-generated rows we check that
+  // GenerateOrdering agrees with RowOrdering.
+  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
+    test(s"GenerateOrdering with $dataType") {
+      val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+      val genOrdering = GenerateOrdering.generate(
+        BoundReference(0, dataType, nullable = true).asc ::
+          BoundReference(1, dataType, nullable = true).asc :: Nil)
+      val rowType = StructType(
+        StructField("a", dataType, nullable = true) ::
+          StructField("b", dataType, nullable = true) :: Nil)
+      val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+      assume(maybeDataGenerator.isDefined)
+      val randGenerator = maybeDataGenerator.get
+      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+      for (_ <- 1 to 50) {
+        val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        withClue(s"a = $a, b = $b") {
+          assert(genOrdering.compare(a, a) === 0)
+          assert(genOrdering.compare(b, b) === 0)
+          assert(rowOrdering.compare(a, a) === 0)
+          assert(rowOrdering.compare(b, b) === 0)
+          assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+          assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+          assert(
+            signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+            "Generated and non-generated orderings should agree")
+        }
+      }
+    }
+  }
+
   test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
     val length = 5000
     val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2173a0c25c645c017a843d7f5a5061c820360a7b..0bc2812a5dc83e094f7ad364a51c05bfc63025e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
-  private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
-
-  private val equalValues1 = smallValues
-  private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
+  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+  private val largeValues =
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+
+  private val equalValues1 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+  private val equalValues2 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
 
   test("BinaryComparison: <") {
     for (i <- 0 until smallValues.length) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index d00aeb4dfbf472fe4e0018479971d18e9292a06d..dff5faf9f6ec8a4d72d7f7b45c02cfac68ec3224 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
+  test("NaN canonicalization") {
+    val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
+
+    val row1 = new SpecificMutableRow(fieldTypes)
+    row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
+    row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
+
+    val row2 = new SpecificMutableRow(fieldTypes)
+    row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
+    row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
+
+    val converter = new UnsafeRowConverter(fieldTypes)
+    val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
+    val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
+    converter.writeRow(
+      row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
+    converter.writeRow(
+      row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)
+
+    assert(row1Buffer.toSeq === row2Buffer.toSeq)
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 192cc0a6e5d7c5ae384ec0d8eee6270cceebd3a5..f67f2c60c0e164df1ec2b3d8d3c7095a746ec36c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 
 import scala.language.postfixOps
+import scala.util.Random
 
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.functions._
@@ -742,6 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     df.col("t.``")
   }
 
+  test("SPARK-8797: sort by float column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("SPARK-8797: sort by double column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("NaN is greater than all other non-NaN numeric values") {
+    val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
+    val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
+  }
+
   test("SPARK-8072: Better Exception for Duplicate Columns") {
     // only one duplicate column present
     val e = intercept[org.apache.spark.sql.AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index d84b57af9c88288b298f8172f45e15a22ac298c5..7cc6ffd7548d0d7bb2819cbe2da5fe80db07a71f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite {
       row.getAs[Int]("c")
     }
   }
+
+  test("float NaN == NaN") {
+    val r1 = Row(Float.NaN)
+    val r2 = Row(Float.NaN)
+    assert(r1 === r2)
+  }
+
+  test("double NaN == NaN") {
+    val r1 = Row(Double.NaN)
+    val r2 = Row(Double.NaN)
+    assert(r1 === r2)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index 4f4c1f28564cb7ebda002cbafce3ee8afc1dfa06..5fe73f7e0b072cb8dedf9334726c6e2db34e4893 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
   ) {
     test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-      val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
-        case d: Double => !d.isNaN
-        case f: Float => !java.lang.Float.isNaN(f)
-        case x => true
-      }
+      val inputData = Seq.fill(1000)(randomDataGenerator())
       val inputDf = TestSQLContext.createDataFrame(
         TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
         StructType(StructField("a", dataType, nullable = true) :: Nil)