Skip to content
Snippets Groups Projects
Commit b12a76a4 authored by Takeshi YAMAMURO's avatar Takeshi YAMAMURO Committed by gatorsmile
Browse files

[SPARK-19338][SQL] Add UDF names in explain


## What changes were proposed in this pull request?
This pr added a variable for a UDF name in `ScalaUDF`.
Then, if the variable filled, `DataFrame#explain` prints the name.

## How was this patch tested?
Added a test in `UDFSuite`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #16707 from maropu/SPARK-19338.

(cherry picked from commit 9f523d31)
Signed-off-by: default avatargatorsmile <gatorsmile@gmail.com>
parent 0d7e3852
No related branches found
No related tags found
No related merge requests found
......@@ -1904,7 +1904,7 @@ class Analyzer(
case p => p transformExpressionsUp {
case udf @ ScalaUDF(func, _, inputs, _) =>
case udf @ ScalaUDF(func, _, inputs, _, _) =>
val parameterTypes = ScalaReflection.getParameterTypes(func)
assert(parameterTypes.length == inputs.length)
......
......@@ -35,17 +35,20 @@ import org.apache.spark.sql.types.DataType
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
* type coercion. However, that would require more refactoring of the codebase.
* @param udfName The user-specified name of this UDF.
*/
case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
inputTypes: Seq[DataType] = Nil)
inputTypes: Seq[DataType] = Nil,
udfName: Option[String] = None)
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
override def nullable: Boolean = true
override def toString: String = s"UDF(${children.mkString(", ")})"
override def toString: String =
s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})"
// scalastyle:off line.size.limit
......
......@@ -17,6 +17,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
......@@ -248,4 +249,17 @@ class UDFSuite extends QueryTest with SharedSQLContext {
sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) tmp").toDF(),
testData2)
}
test("SPARK-19338 Provide identical names for UDFs in the EXPLAIN output") {
def explainStr(df: DataFrame): String = {
val explain = ExplainCommand(df.queryExecution.logical, extended = false)
val sparkPlan = spark.sessionState.executePlan(explain).executedPlan
sparkPlan.executeCollect().map(_.getString(0).trim).headOption.getOrElse("")
}
val udf1 = "myUdf1"
val udf2 = "myUdf2"
spark.udf.register(udf1, (n: Int) => { n + 1 })
spark.udf.register(udf2, (n: Int) => { n * 1 })
assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1(UDF:$udf2(1))"))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment