Skip to content
Snippets Groups Projects
Commit f33d5504 authored by Venkata Ramana Gollamudi's avatar Venkata Ramana Gollamudi Committed by Michael Armbrust
Browse files

[SPARK-3891][SQL] Add array support to percentile, percentile_approx and...

[SPARK-3891][SQL] Add array support to percentile, percentile_approx and constant inspectors support

Supported passing array to percentile and percentile_approx UDAFs
To support percentile_approx, constant inspectors are supported for GenericUDAF
Constant folding support added to CreateArray expression
Avoided constant udf expression re-evaluation

Author: Venkata Ramana G <ramana.gollamudi@huawei.com>

Author: Venkata Ramana Gollamudi <ramana.gollamudi@huawei.com>

Closes #2802 from gvramana/percentile_array_support and squashes the following commits:

a0182e5 [Venkata Ramana Gollamudi] fixed review comment
a18f917 [Venkata Ramana Gollamudi] avoid constant udf expression re-evaluation - fixes failure due to return iterator and value type mismatch
c46db0f [Venkata Ramana Gollamudi] Removed TestHive reset
4d39105 [Venkata Ramana Gollamudi] Unified inspector creation, style check fixes
f37fd69 [Venkata Ramana Gollamudi] Fixed review comments
47f6365 [Venkata Ramana Gollamudi] fixed test
cb7c61e [Venkata Ramana Gollamudi] Supported ConstantInspector for UDAF Fixed HiveUdaf wrap object issue.
7f94aff [Venkata Ramana Gollamudi] Added foldable support to CreateArray
parent 8d0d2a65
No related branches found
No related tags found
No related merge requests found
...@@ -113,7 +113,9 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio ...@@ -113,7 +113,9 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio
*/ */
case class CreateArray(children: Seq[Expression]) extends Expression { case class CreateArray(children: Seq[Expression]) extends Expression {
override type EvaluatedType = Any override type EvaluatedType = Any
override def foldable = !children.exists(!_.foldable)
lazy val childTypes = children.map(_.dataType).distinct lazy val childTypes = children.map(_.dataType).distinct
override lazy val resolved = override lazy val resolved =
......
...@@ -158,6 +158,11 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr ...@@ -158,6 +158,11 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
override def foldable = override def foldable =
isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
@transient
protected def constantReturnValue = unwrap(
returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(),
returnInspector)
@transient @transient
protected lazy val deferedObjects = protected lazy val deferedObjects =
argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
...@@ -166,6 +171,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr ...@@ -166,6 +171,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
override def eval(input: Row): Any = { override def eval(input: Row): Any = {
returnInspector // Make sure initialized. returnInspector // Make sure initialized.
if(foldable) return constantReturnValue
var i = 0 var i = 0
while (i < children.length) { while (i < children.length) {
val idx = i val idx = i
...@@ -193,12 +200,13 @@ private[hive] case class HiveGenericUdaf( ...@@ -193,12 +200,13 @@ private[hive] case class HiveGenericUdaf(
@transient @transient
protected lazy val objectInspector = { protected lazy val objectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
resolver.getEvaluator(parameterInfo)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
} }
@transient @transient
protected lazy val inspectors = children.map(_.dataType).map(toInspector) protected lazy val inspectors = children.map(toInspector)
def dataType: DataType = inspectorToDataType(objectInspector) def dataType: DataType = inspectorToDataType(objectInspector)
...@@ -223,12 +231,13 @@ private[hive] case class HiveUdaf( ...@@ -223,12 +231,13 @@ private[hive] case class HiveUdaf(
@transient @transient
protected lazy val objectInspector = { protected lazy val objectInspector = {
resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
resolver.getEvaluator(parameterInfo)
.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
} }
@transient @transient
protected lazy val inspectors = children.map(_.dataType).map(toInspector) protected lazy val inspectors = children.map(toInspector)
def dataType: DataType = inspectorToDataType(objectInspector) def dataType: DataType = inspectorToDataType(objectInspector)
...@@ -261,7 +270,7 @@ private[hive] case class HiveGenericUdtf( ...@@ -261,7 +270,7 @@ private[hive] case class HiveGenericUdtf(
protected lazy val function: GenericUDTF = funcWrapper.createFunction() protected lazy val function: GenericUDTF = funcWrapper.createFunction()
@transient @transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) protected lazy val inputInspectors = children.map(toInspector)
@transient @transient
protected lazy val outputInspector = function.initialize(inputInspectors.toArray) protected lazy val outputInspector = function.initialize(inputInspectors.toArray)
...@@ -334,10 +343,13 @@ private[hive] case class HiveUdafFunction( ...@@ -334,10 +343,13 @@ private[hive] case class HiveUdafFunction(
} else { } else {
funcWrapper.createFunction[AbstractGenericUDAFResolver]() funcWrapper.createFunction[AbstractGenericUDAFResolver]()
} }
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray private val inspectors = exprs.map(toInspector).toArray
private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray) private val function = {
val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
resolver.getEvaluator(parameterInfo)
}
private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
...@@ -350,9 +362,12 @@ private[hive] case class HiveUdafFunction( ...@@ -350,9 +362,12 @@ private[hive] case class HiveUdafFunction(
@transient @transient
val inputProjection = new InterpretedProjection(exprs) val inputProjection = new InterpretedProjection(exprs)
@transient
protected lazy val cached = new Array[AnyRef](exprs.length)
def update(input: Row): Unit = { def update(input: Row): Unit = {
val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray
function.iterate(buffer, inputs) function.iterate(buffer, wrap(inputs, inspectors, cached))
} }
} }
...@@ -92,10 +92,21 @@ class HiveUdfSuite extends QueryTest { ...@@ -92,10 +92,21 @@ class HiveUdfSuite extends QueryTest {
} }
test("SPARK-2693 udaf aggregates test") { test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq) sql("SELECT max(key) FROM src").collect().toSeq)
checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"),
sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
} }
test("Generic UDAF aggregates") {
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"),
sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq)
}
test("UDFIntegerToString") { test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize( val testData = TestHive.sparkContext.parallelize(
IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment