Commit 9cf56c96 authored by Yin Huai

[SPARK-11469][SQL] Allow users to define nondeterministic udfs.

This is the first task (https://issues.apache.org/jira/browse/SPARK-11469) of https://issues.apache.org/jira/browse/SPARK-11438

Author: Yin Huai <yhuai@databricks.com>

Closes #9393 from yhuai/udfNondeterministic.
parent efaa4721
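In short: `UserDefinedFunction` gains a `deterministic` flag (true by default), and a new `nondeterministic` method returns a copy with the flag cleared, so the optimizer will neither constant-fold such a UDF nor duplicate its invocations. A minimal usage sketch, assuming a 1.6-era `sqlContext` is in scope (`plusOne` is an illustrative name):

import org.apache.spark.sql.functions._

// Deterministic by default: the optimizer may constant-fold this UDF
// or duplicate it when collapsing adjacent projections.
val plusOne = udf((x: Int) => x + 1)

// A copy of the same UDF that the optimizer must leave exactly where,
// and as often as, it was written.
val plusOneNondet = plusOne.nondeterministic

val df = sqlContext.range(1, 2)
  .select(plusOne(col("id")).as("a"), plusOneNondet(col("id")).as("b"))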
@@ -112,6 +112,53 @@ object MimaExcludes {
           "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"),
         ProblemFilters.exclude[MissingClassProblem](
           "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$")
+      ) ++ Seq(
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
       )
     case v if v.startsWith("1.5") =>
       Seq(
...
@@ -30,13 +30,18 @@ case class ScalaUDF(
     function: AnyRef,
     dataType: DataType,
     children: Seq[Expression],
-    inputTypes: Seq[DataType] = Nil)
+    inputTypes: Seq[DataType] = Nil,
+    isDeterministic: Boolean = true)
   extends Expression with ImplicitCastInputTypes with CodegenFallback {
 
   override def nullable: Boolean = true
 
   override def toString: String = s"UDF(${children.mkString(",")})"
 
+  override def foldable: Boolean = deterministic && children.forall(_.foldable)
+
+  override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic)
+
   // scalastyle:off
 
   /** This method has been generated by this script
...
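Taken together, the two overrides make a `ScalaUDF` foldable only when it is deterministic and all of its children are foldable. A small expression-level sketch of what that buys (constructor arguments as in the hunk above; `Literal` and `IntegerType` are standard Catalyst/SQL types):

import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.types.IntegerType

val inc = (x: Int) => x + 1

// Deterministic UDF over a literal child: foldable, so the ConstantFolding
// rule may replace the whole call with Literal(2) at optimization time.
val folded = ScalaUDF(inc, IntegerType, Seq(Literal(1)))
assert(folded.deterministic && folded.foldable)

// The same UDF flagged nondeterministic is neither deterministic nor
// foldable, so every invocation survives optimization.
val kept = ScalaUDF(inc, IntegerType, Seq(Literal(1)), isDeterministic = false)
assert(!kept.deterministic && !kept.foldable)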
@@ -44,11 +44,20 @@ import org.apache.spark.sql.types.DataType
 case class UserDefinedFunction protected[sql] (
     f: AnyRef,
     dataType: DataType,
-    inputTypes: Seq[DataType] = Nil) {
+    inputTypes: Seq[DataType] = Nil,
+    deterministic: Boolean = true) {
 
   def apply(exprs: Column*): Column = {
-    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes))
+    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic))
   }
+
+  protected[sql] def builder: Seq[Expression] => ScalaUDF = {
+    (exprs: Seq[Expression]) =>
+      ScalaUDF(f, dataType, exprs, inputTypes, deterministic)
+  }
+
+  def nondeterministic: UserDefinedFunction =
+    UserDefinedFunction(f, dataType, inputTypes, deterministic = false)
 }
 
 /**
...
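On the registration side, `register` returns the `UserDefinedFunction`, so the nondeterministic copy can be registered under its own SQL-callable name, as the new tests below do:

import org.apache.spark.sql.functions._

// register returns the UserDefinedFunction, so the nondeterministic
// copy can be registered as a second, separately named UDF.
val plusOne = sqlContext.udf.register("plusOne1", (x: Int) => x + 1)
sqlContext.udf.register("plusOne2", plusOne.nondeterministic)

// Both are callable from SQL; only plusOne1 is eligible for
// constant folding or duplication by the optimizer.
sqlContext.range(1, 2).selectExpr("plusOne1(id)", "plusOne2(id)").show()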
@@ -17,6 +17,8 @@
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     // pass a decimal to intExpected.
     assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
   }
+
+  private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = {
+    val udfs = df.queryExecution.optimizedPlan.collect {
+      case p: logical.Project => p.projectList.flatMap {
+        case e => e.collect {
+          case udf: ScalaUDF => udf
+        }
+      }
+    }.flatten
+    assert(udfs.length === expectedNumUDFs)
+  }
+
+  test("foldable udf") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = udf((x: Int) => x + 1)
+
+    {
+      val df = sql("SELECT 1 as a")
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 0)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("nondeterministic udf: using UDFRegistration") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1)
+    sqlContext.udf.register("plusOne2", myUDF.nondeterministic)
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), callUDF("plusOne1", col("a")).as("b"))
+        .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), callUDF("plusOne2", col("a")).as("b"))
+        .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("nondeterministic udf: using udf function") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = udf((x: Int) => x + 1)
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      // nondeterministicUDF will not be foldable.
+      val df = sql("SELECT 1 as a")
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("override a registered udf") {
+    sqlContext.udf.register("intExpected", (x: Int) => x)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+
+    sqlContext.udf.register("intExpected", (x: Int) => x + 1)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2)
+  }
 }
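The expected counts trace the optimizer's behavior: with literal input, a deterministic UDF is folded away entirely (0 `ScalaUDF` nodes left); a deterministic UDF over `range` survives, but collapsing the chained projections appears to duplicate the inner call (3 nodes for 2 written invocations); a nondeterministic UDF presumably blocks that collapsing, so exactly the 2 written invocations remain.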
@@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
     sqlContext.udf.register("div0", (x: Int) => x / 0)
     withTempPath { dir =>
       intercept[org.apache.spark.SparkException] {
-        sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+        sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
       }
       val path = new Path(dir.getCanonicalPath, "_temporary")
       val fs = path.getFileSystem(hadoopConfiguration)
@@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
     sqlContext.udf.register("div0", (x: Int) => x / 0)
     withTempPath { dir =>
       intercept[org.apache.spark.SparkException] {
-        sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+        sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
       }
       val path = new Path(dir.getCanonicalPath, "_temporary")
       val fs = path.getFileSystem(hadoopConfiguration)
...
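These two one-line test changes appear to follow from the new foldability: `select div0(1)` applies the UDF to a literal, so the division by zero would now be raised during constant folding rather than inside the Parquet write job; reading the input from `range(1, 2)` keeps the failure on the runtime write path that these tests exercise.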