diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 74c4cddf2b47e7b75c7288b59cc33731f01e8678..c058425b4bc36b978294060108bb141e7ecba75a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression { } } } + +/** + * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages + * and Hive function wrappers. + */ +trait UserDefinedExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 9df0e2e1415c0033f6821677f7e7576d3077b6a7..527f1670c25e1bcf43e53a9dcfb424081817560e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -47,7 +47,7 @@ case class ScalaUDF( udfName: Option[String] = None, nullable: Boolean = true, udfDeterministic: Boolean = true) - extends Expression with ImplicitCastInputTypes with NonSQLExpression { + extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index ae5e2c6bece2a7bc1b32347fb5d7ad142506faa6..fec1add18cbf211464d4c426c934bd6a155e77f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -324,7 +324,11 @@ case class ScalaUDAF( udaf: UserDefinedAggregateFunction, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes { + extends ImperativeAggregate + with NonSQLExpression + with Logging + with ImplicitCastInputTypes + with UserDefinedExpression { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 59d7e8dd6dffbd13e90091a95626d7bee9f74a73..7ebbdb9846cce094f9c8cc53c864a7d54090caea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression} import org.apache.spark.sql.types.DataType /** @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression]) - extends Expression with Unevaluable with NonSQLExpression { + extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index a83ad61b204ada8eb5aefc187e75feca48096818..e9bdcf00b934625766d442be4cf9ab677aae695f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -42,7 +42,11 @@ import org.apache.spark.sql.types._ private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with CodegenFallback + with Logging + with UserDefinedExpression { override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) @@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp private[hive] case class HiveGenericUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with CodegenFallback with Logging { + extends Expression + with HiveInspectors + with CodegenFallback + with Logging + with UserDefinedExpression { override def nullable: Boolean = true @@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors with CodegenFallback { + extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression { @transient protected lazy val function: GenericUDTF = { @@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction( isUDAFBridgeRequired: Boolean = false, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors { + extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] + with HiveInspectors + with UserDefinedExpression { override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset)