From 8fcbda9c93175c0d44b0e4deaf10df1a427e03ea Mon Sep 17 00:00:00 2001 From: Wang Gengliang <ltnwgl@gmail.com> Date: Tue, 29 Aug 2017 09:08:59 -0700 Subject: [PATCH] [SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-defined functions ## What changes were proposed in this pull request? Add trait UserDefinedExpression to identify user-defined functions. UDFs can be expensive, so the optimizer may need to avoid executing a UDF multiple times. E.g. ```scala table.select(UDF as 'a).select('a, ('a + 1) as 'b) ``` If the UDF is expensive in this case, the optimizer should not collapse the two projections into ```scala table.select(UDF as 'a, (UDF+1) as 'b) ``` Currently, UDF classes such as PythonUDF and HiveGenericUDF are not defined in catalyst. This PR adds a new trait to make it easier to identify user-defined functions. ## How was this patch tested? Unit test Author: Wang Gengliang <ltnwgl@gmail.com> Closes #19064 from gengliangwang/UDFType. --- .../sql/catalyst/expressions/Expression.scala | 6 ++++++ .../sql/catalyst/expressions/ScalaUDF.scala | 2 +- .../spark/sql/execution/aggregate/udaf.scala | 6 +++++- .../spark/sql/execution/python/PythonUDF.scala | 4 ++-- .../org/apache/spark/sql/hive/hiveUDFs.scala | 18 ++++++++++++++---- 5 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 74c4cddf2b..c058425b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression { } } } + +/** + * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages + * and Hive function wrappers. 
+ */ +trait UserDefinedExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 9df0e2e141..527f1670c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -47,7 +47,7 @@ case class ScalaUDF( udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) - extends Expression with ImplicitCastInputTypes with NonSQLExpression { + extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index ae5e2c6bec..fec1add18c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -324,7 +324,11 @@ case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes { + extends ImperativeAggregate + with NonSQLExpression + with Logging + with ImplicitCastInputTypes + with UserDefinedExpression { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 59d7e8dd6d..7ebbdb9846 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} import org.apache.spark.sql.types.DataType /** @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression]) - extends Expression with Unevaluable with NonSQLExpression { + extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a83ad61b20..e9bdcf00b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -42,7 +42,11 @@ import org.apache.spark.sql.types._ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with CodegenFallback + with Logging + with UserDefinedExpression { override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) @@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp private[hive] case class HiveGenericUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with 
CodegenFallback + with Logging + with UserDefinedExpression { override def nullable: Boolean = true @@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors with CodegenFallback { + extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression { @transient protected lazy val function: GenericUDTF = { @@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] + with HiveInspectors + with UserDefinedExpression { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) -- GitLab