Skip to content
Snippets Groups Projects
Commit 8fcbda9c authored by Wang Gengliang's avatar Wang Gengliang Committed by gatorsmile
Browse files

[SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-defined functions

## What changes were proposed in this pull request?

Add trait UserDefinedExpression to identify user-defined functions.
UDF can be expensive. In optimizer we may need to avoid executing UDF multiple times.
E.g.
```scala
table.select(UDF as 'a).select('a, ('a + 1) as 'b)
```
If UDF is expensive in this case, optimizer should not collapse the project to
```scala
table.select(UDF as 'a, (UDF+1) as 'b)
```

Currently UDF classes like PythonUDF, HiveGenericUDF are not defined in catalyst.
This PR is to add a new trait to make it easier to identify user-defined functions.

## How was this patch tested?

Unit test

Author: Wang Gengliang <ltnwgl@gmail.com>

Closes #19064 from gengliangwang/UDFType.
parent 32fa0b81
No related branches found
No related tags found
No related merge requests found
......@@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression {
}
}
}
/**
* Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
* and Hive function wrappers.
*/
trait UserDefinedExpression
......@@ -47,7 +47,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
......
......@@ -324,7 +324,11 @@ case class ScalaUDAF(
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
extends ImperativeAggregate
with NonSQLExpression
with Logging
with ImplicitCastInputTypes
with UserDefinedExpression {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
......
......@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
import org.apache.spark.sql.types.DataType
/**
......@@ -29,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
extends Expression with Unevaluable with NonSQLExpression {
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
......
......@@ -42,7 +42,11 @@ import org.apache.spark.sql.types._
private[hive] case class HiveSimpleUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {
override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
......@@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with CodegenFallback with Logging {
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {
override def nullable: Boolean = true
......@@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF(
name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
extends Generator with HiveInspectors with CodegenFallback {
extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression {
@transient
protected lazy val function: GenericUDTF = {
......@@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
with HiveInspectors
with UserDefinedExpression {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment