From 27e815c31de26636df089b0b8d9bd678b92d3588 Mon Sep 17 00:00:00 2001
From: Sean Zhong <seanzhong@databricks.com>
Date: Thu, 4 Aug 2016 13:43:25 +0800
Subject: [PATCH] [SPARK-16888][SQL] Implements eval method for expression
 AssertNotNull

## What changes were proposed in this pull request?

Implements `eval()` method for expression `AssertNotNull` so that we can convert local projection on LocalRelation to another LocalRelation.

### Before change:
```
scala> import org.apache.spark.sql.catalyst.dsl.expressions._
scala> import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
scala> import org.apache.spark.sql.Column
scala> case class A(a: Int)
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain

java.lang.UnsupportedOperationException: Only code-generated evaluation is supported.
  at org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull.eval(objects.scala:850)
  ...
```

### After the change:
```
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain(true)

== Parsed Logical Plan ==
'Project [assertnotnull('_1) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]

== Analyzed Logical Plan ==
assertnotnull(_1): struct<a:int>
Project [assertnotnull(_1#2) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]

== Optimized Logical Plan ==
LocalRelation [assertnotnull(_1)#5]

== Physical Plan ==
LocalTableScan [assertnotnull(_1)#5]
```

## How was this patch tested?

Unit test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14486 from clockfly/assertnotnull_eval.
---
 .../expressions/objects/objects.scala         | 20 ++++++++++++-------
 .../expressions/NullFunctionsSuite.scala      |  8 ++++++++
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 952a5f3b04..7cb94a7942 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -859,17 +859,23 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
   override def foldable: Boolean = false
   override def nullable: Boolean = false
 
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+  private val errMsg = "Null value appeared in non-nullable field:" +
+    walkedTypePath.mkString("\n", "\n", "\n") +
+    "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+    "please try to use scala.Option[_] or other nullable types " +
+    "(e.g. java.lang.Integer instead of int/scala.Int)."
+
+  override def eval(input: InternalRow): Any = {
+    val result = child.eval(input)
+    if (result == null) {
+      throw new RuntimeException(errMsg);
+    }
+    result
+  }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val childGen = child.genCode(ctx)
 
-    val errMsg = "Null value appeared in non-nullable field:" +
-      walkedTypePath.mkString("\n", "\n", "\n") +
-      "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
-      "please try to use scala.Option[_] or other nullable types " +
-      "(e.g. java.lang.Integer instead of int/scala.Int)."
     val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
 
     val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index 712fe35f47..e736379930 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
 import org.apache.spark.sql.types._
 
 class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -45,6 +46,13 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("AssertNotNUll") {
+    val ex = intercept[RuntimeException] {
+      evaluate(AssertNotNull(Literal(null), Seq.empty[String]))
+    }.getMessage
+    assert(ex.contains("Null value appeared in non-nullable field"))
+  }
+
   test("IsNaN") {
     checkEvaluation(IsNaN(Literal(Double.NaN)), true)
     checkEvaluation(IsNaN(Literal(Float.NaN)), true)
-- 
GitLab