diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a45181712dbdfd3a6c13b281ad27a827ea4b6fd9..7bb2579506a8ab1db08aa8fb497af3691803dc55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -79,6 +79,7 @@ object FunctionRegistry {
     expression[Explode]("explode"),
     expression[Greatest]("greatest"),
     expression[If]("if"),
+    expression[IsNaN]("isnan"),
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7a6fb2b3788ca80c0059fdff731747aed155ac35..2751c8e75f357606f6a4372180861cb5464c50ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -120,6 +120,56 @@ case class InSet(child: Expression, hset: Set[Any])
   }
 }
 
+/**
+ * Evaluates to `true` if it's NaN or null
+ */
+case class IsNaN(child: Expression) extends UnaryExpression
+  with Predicate with ImplicitCastInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      true
+    } else {
+      child.dataType match {
+        case DoubleType => value.asInstanceOf[Double].isNaN
+        case FloatType => value.asInstanceOf[Float].isNaN
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val eval = child.gen(ctx)
+    child.dataType match {
+      case FloatType =>
+        s"""
+          ${eval.code}
+          boolean ${ev.isNull} = false;
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (${eval.isNull}) {
+            ${ev.primitive} = true;
+          } else {
+            ${ev.primitive} = Float.isNaN(${eval.primitive});
+          }
+        """
+      case DoubleType =>
+        s"""
+          ${eval.code}
+          boolean ${ev.isNull} = false;
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (${eval.isNull}) {
+            ${ev.primitive} = true;
+          } else {
+            ${ev.primitive} = Double.isNaN(${eval.primitive});
+          }
+        """
+    }
+  }
+}
 
 case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 188ecef9e7679a2051f19de18a51bb5953c713bf..052abc51af5fd6b9bf84a09aa0ab39740176a13f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType}
+import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
 
 class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -116,6 +116,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
       true)
   }
 
+  test("IsNaN") {
+    checkEvaluation(IsNaN(Literal(Double.NaN)), true)
+    checkEvaluation(IsNaN(Literal(Float.NaN)), true)
+    checkEvaluation(IsNaN(Literal(math.log(-3))), true)
+    checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
+    checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
+    checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
+    checkEvaluation(IsNaN(Literal(5.5f)), false)
+  }
+
   test("INSET") {
     val hS = HashSet[Any]() + 1 + 2
     val nS = HashSet[Any]() + 1 + 2 + null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 10250264625b2c209cd500093630faa6566cb828..221cd04c6d288609661dc34ff328d95f6bf535af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -400,6 +400,14 @@ class Column(protected[sql] val expr: Expression) extends Logging {
     (this >= lowerBound) && (this <= upperBound)
   }
 
+  /**
+   * True if the current expression is NaN or null
+   *
+   * @group expr_ops
+   * @since 1.5.0
+   */
+  def isNaN: Column = IsNaN(expr)
+
   /**
    * True if the current expression is null.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index fe511c296cfd2ce88e4334ad14fb8b3b8c365617..b56fd9a71b3214fb610a68306475222a977d6fa7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -620,7 +620,15 @@ object functions {
   def explode(e: Column): Column = Explode(e.expr)
 
   /**
-   * Converts a string exprsesion to lower case.
+   * Return true if the column is NaN or null
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  def isNaN(e: Column): Column = IsNaN(e.expr)
+
+  /**
+   * Converts a string expression to lower case.
    *
    * @group normal_funcs
    * @since 1.3.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 88bb743ab0bc9147efbf56b828f905bdbf3edd91..8f15479308391b6b7283626e17708739d15bc36d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -201,6 +201,27 @@ class ColumnExpressionSuite extends QueryTest {
       Row(false, true))
   }
 
+  test("isNaN") {
+    val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
+      Row(Double.NaN, Float.NaN) ::
+      Row(math.log(-1), math.log(-3).toFloat) ::
+      Row(null, null) ::
+      Row(Double.MaxValue, Float.MinValue):: Nil),
+      StructType(Seq(StructField("a", DoubleType), StructField("b", FloatType))))
+
+    checkAnswer(
+      testData.select($"a".isNaN, $"b".isNaN),
+      Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil)
+
+    checkAnswer(
+      testData.select(isNaN($"a"), isNaN($"b")),
+      Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil)
+
+    checkAnswer(
+      ctx.sql("select isnan(15), isnan('invalid')"),
+      Row(false, true))
+  }
+
   test("===") {
     checkAnswer(
       testData2.filter($"a" === 1),
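
For context, a minimal usage sketch (not part of the patch) exercising the three entry points the change adds: the Column method, the functions helper, and the SQL name registered in FunctionRegistry. It assumes a Spark 1.5-era SQLContext bound to a value named sqlContext; the column name, temp table name and sample data are illustrative only. Note that, as documented in this version, null inputs also evaluate to true.

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.isNaN
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Illustrative data: one ordinary value, one NaN, one null.
val rows = sqlContext.sparkContext.parallelize(
  Row(1.0) :: Row(Double.NaN) :: Row(null) :: Nil)
val df = sqlContext.createDataFrame(rows, StructType(Seq(StructField("v", DoubleType))))

// DataFrame API: Column.isNaN and functions.isNaN.
df.select(df("v").isNaN, isNaN(df("v"))).show()

// SQL: the expression is registered under the name "isnan".
df.registerTempTable("nan_test")
sqlContext.sql("SELECT isnan(v) FROM nan_test").show()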