Skip to content
Snippets Groups Projects
Commit 27e815c3 authored by Sean Zhong's avatar Sean Zhong Committed by Wenchen Fan
Browse files

[SPARK-16888][SQL] Implements eval method for expression AssertNotNull

## What changes were proposed in this pull request?

Implements `eval()` method for expression `AssertNotNull` so that we can convert local projection on LocalRelation to another LocalRelation.

### Before change:
```
scala> import org.apache.spark.sql.catalyst.dsl.expressions._
scala> import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
scala> import org.apache.spark.sql.Column
scala> case class A(a: Int)
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain

java.lang.UnsupportedOperationException: Only code-generated evaluation is supported.
  at org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull.eval(objects.scala:850)
  ...
```

### After the change:
```
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain(true)

== Parsed Logical Plan ==
'Project [assertnotnull('_1) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]

== Analyzed Logical Plan ==
assertnotnull(_1): struct<a:int>
Project [assertnotnull(_1#2) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]

== Optimized Logical Plan ==
LocalRelation [assertnotnull(_1)#5]

== Physical Plan ==
LocalTableScan [assertnotnull(_1)#5]
```

## How was this patch tested?

Unit test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14486 from clockfly/assertnotnull_eval.
parent 780c7224
No related branches found
No related tags found
No related merge requests found
......@@ -859,17 +859,23 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
override def foldable: Boolean = false
override def nullable: Boolean = false
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
private val errMsg = "Null value appeared in non-nullable field:" +
walkedTypePath.mkString("\n", "\n", "\n") +
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
override def eval(input: InternalRow): Any = {
val result = child.eval(input)
if (result == null) {
throw new RuntimeException(errMsg);
}
result
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
val errMsg = "Null value appeared in non-nullable field:" +
walkedTypePath.mkString("\n", "\n", "\n") +
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val code = s"""
......
......@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.types._
class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
......@@ -45,6 +46,13 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
test("AssertNotNUll") {
val ex = intercept[RuntimeException] {
evaluate(AssertNotNull(Literal(null), Seq.empty[String]))
}.getMessage
assert(ex.contains("Null value appeared in non-nullable field"))
}
test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment