From 72bf5199738c7ab0361b2b55eb4f4299048a21fa Mon Sep 17 00:00:00 2001 From: Zhan Zhang <zhanzhang@fb.com> Date: Fri, 9 Dec 2016 16:35:06 +0800 Subject: [PATCH] [SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic Make stateful udf as nondeterministic Add new test cases with both Stateful and Stateless UDF. Without the patch, the test cases will throw exception: 1 did not equal 10 ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501) org.scalatest.exceptions.TestFailedException: 1 did not equal 10 at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500) at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555) ... Author: Zhan Zhang <zhanzhang@fb.com> Closes #16068 from zhzhan/state. (cherry picked from commit 67587d961d5f94a8639c20cb80127c86bf79d5a8) Signed-off-by: Wenchen Fan <wenchen@databricks.com> --- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 +- .../sql/hive/execution/HiveUDFSuite.scala | 45 ++++++++++++++++++- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e30e0f9611..37414ad129 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -59,7 +59,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) @@ -142,7 +142,7 @@ private[hive] case class HiveGenericUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 48adc833f4..4098bb597b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.io.Writable +import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(count4 == 1) sql("DROP TABLE parquet_tmp") } + + test("Hive Stateful UDF") { + withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) { + sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'") + val testData = spark.range(10).repartition(1) + + // Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1. + checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10)) + + // Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1, + // and the data is evenly distributed into 2 partitions. + checkAnswer(testData.repartition(2) + .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5)) + + // Expected Max(s) is 1, as stateless UDF is deterministic and foldable and replaced + // by constant 1 by ConstantFolding optimizer. + checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { @@ -551,3 +573,22 @@ class PairUDF extends GenericUDF { override def getDisplayString(p1: Array[String]): String = "" } + +@UDFType(stateful = true) +class StatefulUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +} + +class StatelessUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +} -- GitLab