Commit aad644fb authored by Yin Huai, committed by Michael Armbrust

[SPARK-10639] [SQL] Need to convert UDAF's result from Scala to SQL type

https://issues.apache.org/jira/browse/SPARK-10639

Author: Yin Huai <yhuai@databricks.com>

Closes #8788 from yhuai/udafConversion.
parent e0dc2bc2
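
In short: ScalaUDAF.eval used to hand the UDAF's Scala-typed result straight back to the engine, which breaks whenever the external and internal representations differ (a java.sql.Date versus an Int day count, a Row versus an InternalRow, and so on). The patch routes the result through a converter derived from the UDAF's dataType. Below is a minimal, self-contained sketch of that convert-at-the-boundary idea; the names and the toy internal encodings are illustrative, not Spark's internal API.

object ConvertOnOutputSketch {
  // Stand-in for CatalystTypeConverters.createToCatalystConverter(dataType):
  // maps an external (Scala) value to a made-up internal representation.
  val toInternal: Any => Any = {
    case null => null
    case d: java.sql.Date => d.toLocalDate.toEpochDay.toInt // dates as a day count
    case s: String => s.getBytes("UTF-8")                   // strings as UTF-8 bytes
    case other => other
  }

  // Stand-in for the patched ScalaUDAF.eval: convert whatever the UDAF returned.
  def eval(udafResult: Any): Any = toInternal(udafResult)

  def main(args: Array[String]): Unit = {
    println(eval(java.sql.Date.valueOf("2015-09-17"))) // 16695
    println(eval(null))                                // null
  }
}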
@@ -138,8 +138,13 @@ object CatalystTypeConverters {
private case class UDTConverter(
udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
// toCatalyst (which calls toCatalystImpl) does the null check.
override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
override def toScala(catalystValue: Any): Any = {
if (catalystValue == null) null else udt.deserialize(catalystValue)
}
override def toScalaImpl(row: InternalRow, column: Int): Any =
toScala(row.get(column, udt.sqlType))
}
......
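
The UDTConverter change gives toScala an explicit null guard so udt.deserialize never sees a null catalyst value (toCatalyst already gets its null check from the wrapper, as the new comment notes). A standalone sketch of the same guard-in-the-wrapper pattern, using made-up types rather than Spark's UserDefinedType:

trait NullSafeConverter[C <: AnyRef, S <: AnyRef] {
  protected def deserialize(catalystValue: C): S
  // The public entry point owns the null check, so implementations never see null.
  final def toScala(catalystValue: C): S =
    if (catalystValue == null) null.asInstanceOf[S] else deserialize(catalystValue)
}

// Hypothetical UDT-like converter: internal form Array[Double], external form Vector[Double].
object DenseVectorLikeConverter extends NullSafeConverter[Array[Double], Vector[Double]] {
  protected def deserialize(values: Array[Double]): Vector[Double] = values.toVector
}

// DenseVectorLikeConverter.toScala(null) now returns null instead of throwing.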
@@ -108,7 +108,21 @@ object RandomDataGenerator {
arr
})
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case DateType =>
val generator =
() => {
var milliseconds = rand.nextLong() % 253402329599999L
// -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
// for "0001-01-01 00:00:00.000000". We need a value that is greater than or equal to
// this number to stay within the valid range.
while (milliseconds < -62135740800000L) {
// 253402329599999L is the number of milliseconds since
// January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
milliseconds = rand.nextLong() % 253402329599999L
}
DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt)
}
Some(generator)
case TimestampType =>
val generator =
() => {
......
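
The new DateType generator rejection-samples a millisecond value inside the supported range and truncates it to whole days, instead of feeding an arbitrary Int to java.sql.Date. A standalone version of the same idea; MinMillis and MaxMillis are the two constants used in the patch, and MillisPerDay stands in for DateTimeUtils.MILLIS_PER_DAY.

import scala.util.Random

object RandomDateSketch {
  private val MinMillis = -62135740800000L // "0001-01-01 00:00:00" relative to the epoch
  private val MaxMillis = 253402329599999L // "9999-12-31 23:59:59.999999" relative to the epoch
  private val MillisPerDay = 24L * 60 * 60 * 1000

  def randomDate(rand: Random): java.sql.Date = {
    var millis = rand.nextLong() % MaxMillis
    // Re-draw until the value is not earlier than year 1; the % above can yield anything
    // down to -MaxMillis, so a few retries may be needed.
    while (millis < MinMillis) {
      millis = rand.nextLong() % MaxMillis
    }
    new java.sql.Date((millis / MillisPerDay) * MillisPerDay) // truncate to a whole day
  }
}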
@@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < getters.length) {
getters(i) = dataTypes(i) match {
case NullType =>
(row: InternalRow, ordinal: Int) => null
case BooleanType =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
@@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils {
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
case DateType =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
case TimestampType =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
case other =>
(row: InternalRow, ordinal: Int) =>
if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
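
The new DateType and TimestampType getters return the buffer's internal representations directly: a date is stored as an Int counting days since 1970-01-01 and a timestamp as a Long counting microseconds since the epoch. A plain-JDK illustration of those encodings (not Spark's DateTimeUtils, and ignoring pre-epoch corner cases):

import java.time.LocalDate

object InternalEncodingSketch {
  def dateToDays(d: java.sql.Date): Int = d.toLocalDate.toEpochDay.toInt

  def daysToDate(days: Int): java.sql.Date =
    java.sql.Date.valueOf(LocalDate.ofEpochDay(days.toLong))

  def timestampToMicros(t: java.sql.Timestamp): Long = {
    val instant = t.toInstant
    instant.getEpochSecond * 1000000L + instant.getNano / 1000L
  }

  def main(args: Array[String]): Unit = {
    val d = java.sql.Date.valueOf("2015-09-17")
    assert(dateToDays(d) == 16695)
    assert(daysToDate(16695) == d)
  }
}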
@@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils {
var i = 0
while (i < setters.length) {
setters(i) = dataTypes(i) match {
case NullType =>
(row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal)
case b: BooleanType =>
(row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
@@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils {
case dt: DecimalType =>
val precision = dt.precision
(row: MutableRow, ordinal: Int, value: Any) =>
// To make it work with UnsafeRow, we cannot use setNullAt.
// See the comment on UnsafeRow.setDecimal for details.
row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
case DateType =>
(row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
row.setInt(ordinal, value.asInstanceOf[Int])
} else {
row.setNullAt(ordinal)
}
case TimestampType =>
(row: MutableRow, ordinal: Int, value: Any) =>
if (value != null) {
row.setLong(ordinal, value.asInstanceOf[Long])
} else {
row.setNullAt(ordinal)
}
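
Both the getter and the setter tables follow the same pattern: resolve each field's data type once, up front, into an array of small functions, so the per-row loop does no type dispatch. A scaled-down sketch of that pattern with stand-in types (Record below is not Spark's InternalRow):

sealed trait FieldType
case object IntField extends FieldType
case object LongField extends FieldType
case object AnyField extends FieldType

// A toy row: a null slot represents SQL NULL.
final class Record(values: Array[Any]) {
  def isNullAt(i: Int): Boolean = values(i) == null
  def getInt(i: Int): Int = values(i).asInstanceOf[Int]
  def getLong(i: Int): Long = values(i).asInstanceOf[Long]
  def get(i: Int): Any = values(i)
}

object AccessorTableSketch {
  // Built once per schema; indexed by ordinal inside the row-processing loop.
  def createGetters(types: Array[FieldType]): Array[(Record, Int) => Any] =
    types.map {
      case IntField  => (r: Record, i: Int) => if (r.isNullAt(i)) null else r.getInt(i)
      case LongField => (r: Record, i: Int) => if (r.isNullAt(i)) null else r.getLong(i)
      case AnyField  => (r: Record, i: Int) => if (r.isNullAt(i)) null else r.get(i)
    }
}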
@@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl (
throw new IllegalArgumentException(
s"Could not access ${i}th value in this buffer because it only has $length values.")
}
toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
}
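
Reading a value out of the aggregation buffer is a two-step composition: the type-specific getter fetches the raw internal value, then the field's to-Scala converter turns it into the external type the UDAF author sees. A standalone sketch of that composition with plain arrays and functions (not Spark classes):

object BufferReadSketch {
  type Getter = (Array[Any], Int) => Any

  // Step 1: pull the raw internal value out of the buffer (stand-in for bufferValueGetters).
  val rawGetter: Getter = (buffer, ordinal) => buffer(ordinal)

  // Step 2: convert to the external Scala type (stand-in for toScalaConverters).
  val toScala: Any => Any = {
    case null => null
    case d: java.math.BigDecimal => BigDecimal(d) // internal decimal -> Scala BigDecimal
    case other => other
  }

  def get(buffer: Array[Any], ordinal: Int): Any =
    toScala(rawGetter(buffer, ordinal))

  // get(Array[Any](new java.math.BigDecimal("1.50")), 0) == BigDecimal("1.50")
}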
@@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF(
}
}
private[this] lazy val outputToCatalystConverter: Any => Any = {
CatalystTypeConverters.createToCatalystConverter(dataType)
}
// This buffer is only used at executor side.
private[this] var inputAggregateBuffer: InputAggregationBuffer = null
@@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF(
override def eval(buffer: InternalRow): Any = {
evalAggregateBuffer.underlyingInputBuffer = buffer
udaf.evaluate(evalAggregateBuffer)
outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))
}
override def toString: String = {
......
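
outputToCatalystConverter is a lazy val, so the converter for the UDAF's dataType is built at most once per operator instance rather than once per output row. A small sketch of that build-once behavior (UdafLikeWrapper is illustrative, not a Spark class):

final class UdafLikeWrapper(buildConverter: () => Any => Any) {
  // Built lazily, exactly once, the first time eval is called.
  private[this] lazy val outputConverter: Any => Any = buildConverter()

  def eval(rawResult: Any): Any = outputConverter(rawResult)
}

object UdafLikeWrapperDemo {
  def main(args: Array[String]): Unit = {
    var builds = 0
    val wrapper = new UdafLikeWrapper(() => { builds += 1; (x: Any) => x })
    wrapper.eval("a")
    wrapper.eval("b")
    assert(builds == 1) // the converter was constructed a single time
  }
}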
@@ -115,19 +115,26 @@ object QueryTest {
*/
def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
// We need to call prepareRow recursively to handle schemas with struct types.
def prepareRow(row: Row): Row = {
Row.fromSeq(row.toSeq.map {
case null => null
case d: java.math.BigDecimal => BigDecimal(d)
// Convert array to Seq for easy equality check.
case b: Array[_] => b.toSeq
case r: Row => prepareRow(r)
case o => o
})
}
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types on which we can do equality comparison using Scala collections.
// For BigDecimal values, the Scala type has a better definition of equality (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
// the equality test.
val converted: Seq[Row] = answer.map { s =>
Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case b: Array[Byte] => b.toSeq
case o => o
})
}
val converted: Seq[Row] = answer.map(prepareRow)
if (!isSorted) converted.sortBy(_.toString()) else converted
}
val sparkAnswer = try df.collect().toSeq catch {
......
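
prepareRow recurses because nested structs come back as nested Rows: java.math.BigDecimal is mapped to Scala's BigDecimal (whose equality behaves like compareTo) and arrays are mapped to Seq, since == on arrays compares references. A standalone illustration of the array pitfall, with Seq standing in for Row:

object NestedNormalizeSketch {
  def normalize(value: Any): Any = value match {
    case null => null
    case arr: Array[_] => arr.toSeq.map(normalize)
    case seq: Seq[_] => seq.map(normalize)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val a = Seq(1, Array[Byte](1, 2, 3))
    val b = Seq(1, Array[Byte](1, 2, 3))
    assert(a != b)                       // the nested arrays compare by reference
    assert(normalize(a) == normalize(b)) // after normalization they compare by value
  }
}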
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
assert(new MyDenseVectorUDT().typeName === "mydensevector")
assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
}
test("Catalyst type converter null handling for UDTs") {
val udt = new MyDenseVectorUDT()
val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
assert(toScalaConverter(null) === null)
val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt)
assert(toCatalystConverter(null) === null)
}
}
@@ -17,13 +17,55 @@
package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
import org.apache.spark.sql.hive.test.TestHiveSingleton
class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
def inputSchema: StructType = schema
def bufferSchema: StructType = schema
def dataType: DataType = schema
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
(0 until schema.length).foreach { i =>
buffer.update(i, null)
}
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0) && input.getInt(0) == 50) {
(0 until schema.length).foreach { i =>
buffer.update(i, input.get(i))
}
}
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
(0 until schema.length).foreach { i =>
buffer1.update(i, buffer2.get(i))
}
}
}
def evaluate(buffer: Row): Any = {
Row.fromSeq(buffer.toSeq)
}
}
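
ScalaAggregateFunction simply remembers the entire input row whose first column equals 50, so the expected aggregate output is exactly that row; what the test exercises is that every value survives the round trip through the aggregation buffer and the output converter for each supported data type. Its update/merge contract, restated as a plain Scala fold (a model of the test UDAF, not Spark code):

object KeepRow50Sketch {
  type R = IndexedSeq[Any]

  // update: once the row with col0 == 50 is seen, keep it in the buffer.
  def update(buffer: Option[R], input: R): Option[R] =
    input.headOption match {
      case Some(50) => Some(input)
      case _        => buffer
    }

  // merge: a partition that saw the row wins over one that did not.
  def merge(b1: Option[R], b2: Option[R]): Option[R] = b2.orElse(b1)

  def main(args: Array[String]): Unit = {
    val rows: Seq[R] = Seq(Vector(49, "a"), Vector(50, "b"), Vector(51, "c"))
    val result = rows.foldLeft(Option.empty[R])(update)
    assert(result.contains(Vector(50, "b")))
  }
}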
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
}
}
test("udaf with all data types") {
val struct =
StructType(
StructField("f1", FloatType, true) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType), MapType(StringType, LongType), struct,
new MyDenseVectorUDT())
// Right now, we use SortBasedAggregate to handle UDAFs.
// UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
// UnsafeRow as the aggregation buffer, while dataTypes will trigger
// SortBasedAggregate to use a safe row as the aggregation buffer.
Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, nullable = true)
}
// The schema used for data generator.
val schemaForGenerator = StructType(fields)
// The schema used for the DataFrame df.
val schema = StructType(StructField("id", IntegerType) +: fields)
logInfo(s"Testing schema: ${schema.treeString}")
val udaf = new ScalaAggregateFunction(schema)
// Generate data on the driver side. We need to materialize the data first and then
// create the RDD.
val maybeDataGenerator =
RandomDataGenerator.forType(
dataType = schemaForGenerator,
nullable = true,
seed = Some(System.nanoTime()))
val dataGenerator =
maybeDataGenerator
.getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
val data = (1 to 50).map { i =>
dataGenerator.apply() match {
case row: Row => Row.fromSeq(i +: row.toSeq)
case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
case other =>
fail(s"Row or null is expected to be generated, " +
s"but a ${other.getClass.getCanonicalName} is generated.")
}
}
// Create a DF for the schema with random data.
val rdd = sqlContext.sparkContext.parallelize(data, 1)
val df = sqlContext.createDataFrame(rdd, schema)
val allColumns = df.schema.fields.map(f => col(f.name))
val expectedAnswer =
data
.find(r => r.getInt(0) == 50)
.getOrElse(fail("A row with id 50 should be the expected answer."))
checkAnswer(
df.groupBy().agg(udaf(allColumns: _*)),
// udaf returns a Row as the output value.
Row(expectedAnswer)
)
}
}
}
class SortBasedAggregationQuerySuite extends AggregationQuerySuite {
......