Skip to content
Snippets Groups Projects
Commit bc5f56aa authored by Yin Huai
Browse files

[SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema

https://issues.apache.org/jira/browse/SPARK-12250

Author: Yin Huai <yhuai@databricks.com>

Closes #10236 from yhuai/SPARK-12250.
parent d9d354ed
No related branches found
No related tags found
No related merge requests found
...@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF( ...@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset) copy(inputAggBufferOffset = newInputAggBufferOffset)
require(
children.length == udaf.inputSchema.length,
s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
s"but ${children.length} are provided.")
override def nullable: Boolean = true override def nullable: Boolean = true
override def dataType: DataType = udaf.dataType override def dataType: DataType = udaf.dataType
......
...@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun ...@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
} }
} }
/**
 * Test UDAF that adds up the integer field `v` of every struct in an array-typed argument.
 *
 * Its `inputSchema` is deliberately empty: the function still reads its single argument by
 * position in `update`, which exercises defining a UDAF without describing its inputs.
 */
class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {

  /** Declares no input fields on purpose. */
  def inputSchema: StructType = StructType(Nil)

  /** One LongType slot named "value" that holds the running total. */
  def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)

  /** The aggregate result is a Long. */
  def dataType: DataType = LongType

  def deterministic: Boolean = true

  /** Starts the running total at zero. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
  }

  /**
   * Adds the `v` values of the array found at position 0 of `input` to the running total.
   *
   * @param buffer aggregation buffer whose slot 0 holds the running total
   * @param input  row whose position 0 holds an array of structs with an integer field `v`
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A null array contributes nothing; mapping over it would throw a NullPointerException.
    if (!input.isNullAt(0)) {
      // Widen every element to Long before summing so the per-row total cannot overflow Int.
      val rowTotal = input.getAs[Seq[Row]](0).map(_.getAs[Int]("v").toLong).sum
      buffer.update(0, rowTotal + buffer.getLong(0))
    }
  }

  /** Folds the partial total in `buffer2` into `buffer1`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
  }

  /** Returns the accumulated total stored in slot 0. */
  def evaluate(buffer: Row): Any = {
    buffer.getLong(0)
  }
}
class LongProductSum extends UserDefinedAggregateFunction { class LongProductSum extends UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType() def inputSchema: StructType = new StructType()
.add("a", LongType) .add("a", LongType)
...@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ...@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
) )
} }
} }
// Registers a UDAF whose inputSchema is empty and checks that it can still be called from SQL,
// both with a GROUP BY and as a global aggregate over the whole table.
test("udaf without specifying inputSchema") {
  withTempTable("noInputSchemaUDAF") {
    sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)

    // Each row is (key, array of single-field structs); field `v` carries the values to sum.
    val elementType = StructType(Seq(StructField("v", IntegerType)))
    val tableSchema = StructType(Seq(
      StructField("key", IntegerType),
      StructField("myArray", ArrayType(elementType))))
    val rows = Seq(
      Row(1, Seq(Row(1), Row(2), Row(3))),
      Row(1, Seq(Row(4), Row(5), Row(6))),
      Row(2, Seq(Row(-10))))

    val inputDF = sqlContext.createDataFrame(sparkContext.parallelize(rows, 2), tableSchema)
    inputDF.registerTempTable("noInputSchemaUDAF")

    // Per-key totals: key 1 -> 1+2+3+4+5+6 = 21, key 2 -> -10.
    val perKey = sqlContext.sql(
      """
        |SELECT key, noInputSchema(myArray)
        |FROM noInputSchemaUDAF
        |GROUP BY key
        |""".stripMargin)
    checkAnswer(perKey, Seq(Row(1, 21), Row(2, -10)))

    // Global total across all rows: 21 + (-10) = 11.
    val overall = sqlContext.sql(
      """
        |SELECT noInputSchema(myArray)
        |FROM noInputSchemaUDAF
        |""".stripMargin)
    checkAnswer(overall, Seq(Row(11)))
  }
}
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment