Skip to content
Snippets Groups Projects
Commit 72bf5199 authored by Zhan Zhang, committed by Wenchen Fan
Browse files

[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic


Mark stateful UDFs as nondeterministic.

Add new test cases covering both stateful and stateless UDFs.
Without the patch, the new test fails with the following exception:

1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
        at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
        at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
        ...

Author: Zhan Zhang <zhanzhang@fb.com>

Closes #16068 from zhzhan/state.

(cherry picked from commit 67587d96)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
parent 2c88e1dc
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,7 @@ private[hive] case class HiveSimpleUDF(
// Whether the wrapped Hive UDF may be treated as deterministic. It is true only when a
// HiveUDFType annotation is found on the function's class, that annotation reports
// deterministic(), and it does not report stateful().
@transient
private lazy val isUDFDeterministic = {
// The lookup may yield null (no annotation present), hence the null checks below.
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
// Pre-patch condition (the removed side of this diff hunk). As written here it is not the
// last expression of the block, so its value is discarded.
udfType != null && udfType.deterministic()
// Post-patch condition: a UDF flagged as stateful is never reported as deterministic.
udfType != null && udfType.deterministic() && !udfType.stateful()
}
// Foldable only when the UDF is deterministic and every child expression is foldable.
// The children are not inspected at all when the UDF is not deterministic.
override def foldable: Boolean = {
  if (isUDFDeterministic) children.forall(child => child.foldable) else false
}
......@@ -142,7 +142,7 @@ private[hive] case class HiveGenericUDF(
// Whether the wrapped Hive generic UDF may be treated as deterministic. It is true only when
// a HiveUDFType annotation is found on the function's class, that annotation reports
// deterministic(), and it does not report stateful().
@transient
private lazy val isUDFDeterministic = {
// The lookup may yield null (no annotation present), hence the null checks below.
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
// Pre-patch condition (the removed side of this diff hunk). As written here it is not the
// last expression of the block, so its value is discarded.
udfType != null && udfType.deterministic()
// Post-patch condition: a UDF flagged as stateful is never reported as deterministic.
udfType != null && udfType.deterministic() && !udfType.stateful()
}
@transient
......
......@@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
import java.util.{ArrayList, Arrays, Properties}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.udf.UDAFPercentile
import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.{LongWritable, Writable}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
......@@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
assert(count4 == 1)
sql("DROP TABLE parquet_tmp")
}
test("Hive Stateful UDF") {
  withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
    // Register the two test UDF classes as temporary functions.
    sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'")
    sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'")

    // Ten input rows, all placed in a single partition.
    val singlePartition = spark.range(10).repartition(1)

    // Max of the per-row value produced by each UDF. These are kept as defs so the
    // query is still built inside the checkAnswer call, exactly as before.
    def statefulMax(rows: org.apache.spark.sql.Dataset[_]) =
      rows.selectExpr("statefulUDF() as s").agg(max($"s"))
    def statelessMax(rows: org.apache.spark.sql.Dataset[_]) =
      rows.selectExpr("statelessUDF() as s").agg(max($"s"))

    // statefulUDF returns a sequence number starting from 1, so with one partition of
    // ten rows the expected Max(s) is 10.
    checkAnswer(statefulMax(singlePartition), Row(10))

    // With the data evenly distributed into 2 partitions, each partition counts from 1
    // again, so the expected Max(s) is 5.
    checkAnswer(statefulMax(singlePartition.repartition(2)), Row(5))

    // Expected Max(s) is 1: the stateless UDF is deterministic and foldable, so the
    // ConstantFolding optimizer replaces it with the constant 1.
    checkAnswer(statelessMax(singlePartition), Row(1))
  }
}
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
......@@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
override def getDisplayString(p1: Array[String]): String = ""
}
/**
 * Test UDF annotated as stateful. Every call to `evaluate()` increments an internal counter
 * and returns it, so one instance yields 1, 2, 3, ... across successive calls.
 */
@UDFType(stateful = true)
class StatefulUDF extends UDF {
  // Reused mutable holder; the same object is returned from every call.
  private val counter = new LongWritable(0)

  /** Advances the counter by one and returns the (shared) holder with the new value. */
  def evaluate(): LongWritable = {
    val next = counter.get() + 1
    counter.set(next)
    counter
  }
}
/**
 * Test UDF with the same counting body as a stateful UDF but without a `stateful = true`
 * annotation on the class. Every call to `evaluate()` increments an internal counter and
 * returns it.
 */
class StatelessUDF extends UDF {
  // Reused mutable holder; the same object is returned from every call.
  private val counter = new LongWritable(0)

  /** Advances the counter by one and returns the (shared) holder with the new value. */
  def evaluate(): LongWritable = {
    val next = counter.get() + 1
    counter.set(next)
    counter
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment