diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 438742565c51d58aeee657d7915f6bece8f8cc74..bf1bc5dffba78e2b1fd3ca805979d61d720574d3 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@ import com.google.common.primitives.UnsignedBytes;
 
 import org.apache.spark.annotation.Private;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
 
 @Private
 public class PrefixComparators {
@@ -82,7 +83,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       float a = Float.intBitsToFloat((int) aPrefix);
       float b = Float.intBitsToFloat((int) bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareFloats(a, b);
     }
 
     public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       double a = Double.longBitsToDouble(aPrefix);
       double b = Double.longBitsToDouble(bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareDoubles(a, b);
     }
 
     public long computePrefix(double value) {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e6374f17d858f2b152d087c0dfc010b4289094db..c5816949cd360ce8907f0b25c6f94dfd45fc6f73 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
     hashAbs
   }
 
+  /**
+   * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN double.
+   */
+  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
+    val xIsNan: Boolean = java.lang.Double.isNaN(x)
+    val yIsNan: Boolean = java.lang.Double.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
+  /**
+   * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN float.
+   */
+  def nanSafeCompareFloats(x: Float, y: Float): Int = {
+    val xIsNan: Boolean = java.lang.Float.isNaN(x)
+    val yIsNan: Boolean = java.lang.Float.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
   /** Returns the system properties map that is thread-safe to iterator over. It gets the
     * properties which have been set explicitly, as well as those for which only a default value
     * has been defined. */
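Note for readers: the total order these two helpers implement (NaN equals NaN, NaN above every non-NaN value, everything else as usual) can be exercised outside Spark. A minimal standalone sketch follows; the helper body is copied verbatim from the patch above, while the demo `main` is illustrative only:

```scala
object NanSafeCompareDemo {
  // Body copied from the patched Utils.nanSafeCompareDoubles.
  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
    val xIsNan = java.lang.Double.isNaN(x)
    val yIsNan = java.lang.Double.isNaN(y)
    if ((xIsNan && yIsNan) || (x == y)) 0
    else if (xIsNan) 1
    else if (yIsNan) -1
    else if (x > y) 1
    else -1
  }

  def main(args: Array[String]): Unit = {
    println(nanSafeCompareDoubles(Double.NaN, Double.NaN))              // 0: NaN == NaN
    println(nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity)) // 1: NaN sorts above +Infinity
    println(nanSafeCompareDoubles(-0.0, 0.0))                           // 0: primitive == equates signed zeros
  }
}
```

One divergence from `java.lang.Double.compare` is worth noting: because the helper short-circuits on primitive `==`, `-0.0` and `0.0` compare equal here, whereas `Double.compare(-0.0, 0.0)` returns -1.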
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c7638507c88c60677900a1b1ac429cbca328fd45..8f7e402d5f2a64a1f4806ba2f2b69e7ed598de01 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
 import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     // scalastyle:on println
     assert(buffer.toString === "t circular test circular\n")
   }
+
+  test("nanSafeCompareDoubles") {
+    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+      assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+      assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0d, 0d)
+    shouldMatchDefaultOrder(0d, 1d)
+    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+    assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+  }
+
+  test("nanSafeCompareFloats") {
+    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+      assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+      assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0f, 0f)
+    shouldMatchDefaultOrder(1f, 1f)
+    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+    assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+  }
 }
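Beyond the unit assertions above, a quick usage sketch shows the practical effect: sorting with the comparator pushes every NaN to the end. This is hypothetical REPL-style code, and since `Utils` is `private[spark]` it only compiles from code inside the `org.apache.spark` package:

```scala
import org.apache.spark.util.Utils

val xs = Array(1.0, Double.NaN, -1.0, Double.PositiveInfinity, Double.NaN)
val sorted = xs.sortWith((a, b) => Utils.nanSafeCompareDoubles(a, b) < 0)
// sorted: Array(-1.0, 1.0, Infinity, NaN, NaN)
```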
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505dfa7d758b6e4b2872c13b3bc9a50b2d6d86..dc03e374b51dbd372192fb11fa1e17f85e447dc3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
+
+  test("float prefix comparator handles NaN properly") {
+    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+  }
+
+  test("double prefix comparator handles NaNs properly") {
+    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
+    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
+  }
+
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87294a0e21441f1f35ab2708086e4cb5d24daf9d..8cd9e7bc60a03516d9fb87138fa4a784e5a8316a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -215,6 +215,9 @@
   public void setDouble(int ordinal, double value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
     PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
   }
 
@@ -243,6 +246,9 @@
   public void setFloat(int ordinal, float value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 
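The canonicalization added to setFloat/setDouble exists because IEEE 754 admits many NaN bit patterns, while UnsafeRow comparisons are byte-wise. A small REPL-style sketch of the underlying fact, using the same two payloads as the test suites in this patch:

```scala
// Two distinct NaN encodings (the payloads used by the NaN tests in this patch).
val nan1 = java.lang.Float.intBitsToFloat(0x7f800001)
val nan2 = java.lang.Float.intBitsToFloat(0x7fffffff)
assert(nan1.isNaN && nan2.isNaN)

// The raw encodings differ, so a byte-wise comparison would call them unequal...
assert(java.lang.Float.floatToRawIntBits(nan1) != java.lang.Float.floatToRawIntBits(nan2))

// ...but collapsing to the canonical Float.NaN (bits 0x7fc00000), as the
// patched setters do, gives every NaN one representation.
assert(java.lang.Float.floatToIntBits(nan1) == java.lang.Float.floatToIntBits(nan2))
```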
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 2cb64d00935de8cc15fa439809a1d8ad48910739..91449479fa539dfdf787c454ada5435b83f8e942 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -403,20 +403,28 @@ trait Row extends Serializable {
       if (!isNullAt(i)) {
         val o1 = get(i)
         val o2 = other.get(i)
-        if (o1.isInstanceOf[Array[Byte]]) {
-          // handle equality of Array[Byte]
-          val b1 = o1.asInstanceOf[Array[Byte]]
-          if (!o2.isInstanceOf[Array[Byte]] ||
-            !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+        o1 match {
+          case b1: Array[Byte] =>
+            if (!o2.isInstanceOf[Array[Byte]] ||
+              !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+              return false
+            }
+          case f1: Float if java.lang.Float.isNaN(f1) =>
+            if (!o2.isInstanceOf[Float] || !java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+              return false
+            }
+          case d1: Double if java.lang.Double.isNaN(d1) =>
+            if (!o2.isInstanceOf[Double] || !java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+              return false
+            }
+          case _ => if (o1 != o2) {
             return false
           }
-        } else if (o1 != o2) {
-          return false
         }
       }
       i += 1
     }
-    return true
+    true
   }
 
   /* ---------------------- utility methods for Scala ---------------------- */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 10f411ff7451affb1cadc142efbd6c8f97c2a145..606f770cb4f7bac96c31cdd1950904e9a6d242fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -194,6 +194,8 @@ class CodeGenContext {
    */
   def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
     case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+    case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
+    case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
     case other => s"$c1.equals($c2)"
   }
@@ -204,6 +206,8 @@ class CodeGenContext {
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
     case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
+    case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
+    case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
     // use c1 - c2 may overflow
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
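The strings genEqual now returns for FloatType and DoubleType are spliced into generated Java; written out as an ordinary function, the semantics are simply "both NaN, or primitively equal". A plain-Scala rendering for intuition (the function name is mine, not part of the patch):

```scala
// Mirrors the generated Java expression
// "(java.lang.Double.isNaN(c1) && java.lang.Double.isNaN(c2)) || c1 == c2".
// No extra parentheses are needed around c1 == c2 because || binds more
// loosely than == in both Java and Scala.
def genEqualDouble(c1: Double, c2: Double): Boolean =
  (java.lang.Double.isNaN(c1) && java.lang.Double.isNaN(c2)) || c1 == c2

assert(genEqualDouble(Double.NaN, Double.NaN)) // true under the new semantics
assert(!genEqualDouble(Double.NaN, 1.0))
assert(genEqualDouble(1.0, 1.0))
```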
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 40ec3df224ce14750db7f96d3c48713e50273f12..a53ec31ee6a4bc14465621177e020370f76c8b60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 object InterpretedPredicate {
@@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    if (ctx.isPrimitiveType(left.dataType)) {
+    if (ctx.isPrimitiveType(left.dataType)
+        && left.dataType != FloatType
+        && left.dataType != DoubleType) {
       // faster version
       defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
     } else {
@@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
   override def symbol: String = "="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    if (left.dataType != BinaryType) input1 == input2
-    else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    if (left.dataType == FloatType) {
+      Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+    } else if (left.dataType == DoubleType) {
+      Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+    } else if (left.dataType != BinaryType) {
+      input1 == input2
+    } else {
+      java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     } else if (input1 == null || input2 == null) {
       false
     } else {
-      if (left.dataType != BinaryType) {
+      if (left.dataType == FloatType) {
+        Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+      } else if (left.dataType == DoubleType) {
+        Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+      } else if (left.dataType != BinaryType) {
         input1 == input2
       } else {
         java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
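End to end, routing EqualTo through the NaN-safe comparators makes `NaN = NaN` evaluate to true in SQL comparisons, and makes the interpreted and codegen paths agree on it. A hypothetical spark-shell session against a build containing this patch (`sqlContext` and the implicits import are assumed from the standard shell setup):

```scala
import sqlContext.implicits._

val df = Seq(Double.NaN, 1.0, Double.NaN).map(Tuple1.apply).toDF("a")
// With nanSafeCompareDoubles backing EqualTo, both NaN rows satisfy the predicate.
df.filter($"a" === Double.NaN).count() // expected: 2
```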
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index 986c2ab055386cf7878adc4260540cd67fe3e16c..2a1bf0938e5a832b15daef11ca24e758d2fbfd37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Double] {
+    override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
+  }
   private[sql] val asIntegral = DoubleAsIfIntegral
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 9bd48ece83a1c32b3fda14d32f6f8805fbbc267c..08e22252aef82566284faa32a7f21b588aa942c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Float] {
+    override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
+  }
   private[sql] val asIntegral = FloatAsIfIntegral
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e05218a23aa736dc14444ef021d7fdd0d0695ca6..f4fbc49677ca3c835e4318faf8e32bd7d32bb68d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.math._
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
 
 /**
  * Additional tests for code generation.
@@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     futures.foreach(Await.result(_, 10.seconds))
   }
 
+  // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+  // contain two columns of that type, then for pairs of randomly-generated rows we check that
+  // GenerateOrdering agrees with RowOrdering.
+  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
+    test(s"GenerateOrdering with $dataType") {
+      val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+      val genOrdering = GenerateOrdering.generate(
+        BoundReference(0, dataType, nullable = true).asc ::
+          BoundReference(1, dataType, nullable = true).asc :: Nil)
+      val rowType = StructType(
+        StructField("a", dataType, nullable = true) ::
+          StructField("b", dataType, nullable = true) :: Nil)
+      val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+      assume(maybeDataGenerator.isDefined)
+      val randGenerator = maybeDataGenerator.get
+      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+      for (_ <- 1 to 50) {
+        val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        withClue(s"a = $a, b = $b") {
+          assert(genOrdering.compare(a, a) === 0)
+          assert(genOrdering.compare(b, b) === 0)
+          assert(rowOrdering.compare(a, a) === 0)
+          assert(rowOrdering.compare(b, b) === 0)
+          assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+          assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+          assert(
+            signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+            "Generated and non-generated orderings should agree")
+        }
+      }
+    }
+  }
+
   test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
     val length = 5000
     val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))
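The invariants this new test checks per random pair — reflexivity of each ordering, antisymmetry, and agreement between the generated and interpreted orderings — can be stated on their own. A compact property sketch over a plain Ordering (helper name is mine):

```scala
import scala.math.signum

// The per-pair invariant from the test above, for any total ordering `ord`;
// both RowOrdering and the generated ordering must satisfy it.
def checkOrderingInvariants[T](ord: Ordering[T], a: T, b: T): Unit = {
  assert(ord.compare(a, a) == 0)                                       // reflexive
  assert(signum(ord.compare(a, b)) == -1 * signum(ord.compare(b, a)))  // antisymmetric
}
```

Under the old `(a < b) ? -1 : (a > b) ? 1 : 0` style of comparison, every comparison against NaN returned 0, making NaN "equal" to every value; that breaks transitivity (x equals NaN and NaN equals y while x < y) and violates the comparator contract that sorting relies on.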
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2173a0c25c645c017a843d7f5a5061c820360a7b..0bc2812a5dc83e094f7ad364a51c05bfc63025e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
-  private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
-
-  private val equalValues1 = smallValues
-  private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
+  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+  private val largeValues =
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+
+  private val equalValues1 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+  private val equalValues2 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
 
   test("BinaryComparison: <") {
     for (i <- 0 until smallValues.length) {
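Extending the value matrix means every BinaryComparison test now also exercises NaN: 0f must sort strictly below Float.NaN, and two NaN literals must compare equal. A sketch of what the matrix-driven loops effectively assert, written as if inside PredicateSuite (which mixes in ExpressionEvalHelper for `checkEvaluation`):

```scala
// Hypothetical direct calls mirroring the matrix-driven assertions.
checkEvaluation(LessThan(Literal(0f), Literal(Float.NaN)), true)
checkEvaluation(EqualTo(Literal(Float.NaN), Literal(Float.NaN)), true)
```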
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index d00aeb4dfbf472fe4e0018479971d18e9292a06d..dff5faf9f6ec8a4d72d7f7b45c02cfac68ec3224 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
+  test("NaN canonicalization") {
+    val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
+
+    val row1 = new SpecificMutableRow(fieldTypes)
+    row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
+    row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
+
+    val row2 = new SpecificMutableRow(fieldTypes)
+    row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
+    row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
+
+    val converter = new UnsafeRowConverter(fieldTypes)
+    val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
+    val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
+    converter.writeRow(
+      row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
+    converter.writeRow(
+      row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)
+
+    assert(row1Buffer.toSeq === row2Buffer.toSeq)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 192cc0a6e5d7c5ae384ec0d8eee6270cceebd3a5..f67f2c60c0e164df1ec2b3d8d3c7095a746ec36c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 
 import scala.language.postfixOps
+import scala.util.Random
 
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.functions._
@@ -742,6 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     df.col("t.``")
   }
 
+  test("SPARK-8797: sort by float column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("SPARK-8797: sort by double column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("NaN is greater than all other non-NaN numeric values") {
+    val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
+    val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
+  }
+
   test("SPARK-8072: Better Exception for Duplicate Columns") {
     // only one duplicate column present
     val e = intercept[org.apache.spark.sql.AnalysisException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index d84b57af9c88288b298f8172f45e15a22ac298c5..7cc6ffd7548d0d7bb2819cbe2da5fe80db07a71f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite {
       row.getAs[Int]("c")
     }
   }
+
+  test("float NaN == NaN") {
+    val r1 = Row(Float.NaN)
+    val r2 = Row(Float.NaN)
+    assert(r1 === r2)
+  }
+
+  test("double NaN == NaN") {
+    val r1 = Row(Double.NaN)
+    val r2 = Row(Double.NaN)
+    assert(r1 === r2)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index 4f4c1f28564cb7ebda002cbafce3ee8afc1dfa06..5fe73f7e0b072cb8dedf9334726c6e2db34e4893 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
   ) {
     test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-      val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
-        case d: Double => !d.isNaN
-        case f: Float => !java.lang.Float.isNaN(f)
-        case x => true
-      }
+      val inputData = Seq.fill(1000)(randomDataGenerator())
      val inputDf = TestSQLContext.createDataFrame(
        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
        StructType(StructField("a", dataType, nullable = true) :: Nil)
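Taken together, these tests pin down one semantic: NaN is a single value that equals itself and sorts above every other numeric value. For intuition, plain Scala already agrees on the sort position, since `Ordering.Double` delegates to `java.lang.Double.compare`, which also ranks NaN above +Infinity:

```scala
// Plain Scala, no Spark required: max over a collection containing NaN is NaN,
// matching the new "max(a)" behavior asserted in DataFrameSuite above.
val m = Seq(Double.PositiveInfinity, Double.MaxValue, Double.NaN).max
assert(java.lang.Double.isNaN(m))
```

Where plain Scala and this patch part ways is equality: `Double.NaN == Double.NaN` remains false for primitives, while Spark SQL's Row equality, EqualTo/EqualNullSafe, and the codegen paths above now treat NaN as equal to itself.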