From 72bf5199738c7ab0361b2b55eb4f4299048a21fa Mon Sep 17 00:00:00 2001
From: Zhan Zhang <zhanzhang@fb.com>
Date: Fri, 9 Dec 2016 16:35:06 +0800
Subject: [PATCH] [SPARK-18637][SQL] Stateful UDF should be considered as
 nondeterministic

Treat stateful Hive UDFs as nondeterministic so that they are never considered foldable.

Add a new test case covering both a stateful and a stateless UDF.
Without this patch, the new test case fails with the following exception:

1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
        at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
        at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
        ...

Author: Zhan Zhang <zhanzhang@fb.com>

Closes #16068 from zhzhan/state.

(cherry picked from commit 67587d961d5f94a8639c20cb80127c86bf79d5a8)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
---
 .../org/apache/spark/sql/hive/hiveUDFs.scala  |  4 +-
 .../sql/hive/execution/HiveUDFSuite.scala     | 45 ++++++++++++++++++-
 2 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index e30e0f9611..37414ad129 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -59,7 +59,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
@@ -142,7 +142,7 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   @transient
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 48adc833f4..4098bb597b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
 import java.util.{ArrayList, Arrays, Properties}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.UDAFPercentile
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
 import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.{LongWritable, Writable}
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.functions.max
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
@@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     assert(count4 == 1)
     sql("DROP TABLE parquet_tmp")
   }
+
+  test("Hive Stateful UDF") {
+    withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
+      sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'")
+      sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'")
+      val testData = spark.range(10).repartition(1)
+
+      // Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1.
+      checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10))
+
+      // Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1,
+      // and the data is evenly distributed into 2 partitions.
+      checkAnswer(testData.repartition(2)
+        .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5))
+
+      // Expected Max(s) is 1: statelessUDF is deterministic and foldable, so the ConstantFolding
+      // optimizer replaces the call with the constant 1.
+      checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1))
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
 
   override def getDisplayString(p1: Array[String]): String = ""
 }
+
+@UDFType(stateful = true)
+class StatefulUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
+
+class StatelessUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
-- 
GitLab