From aad644fbe29151aec9004817d42e4928bdb326f3 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Thu, 17 Sep 2015 11:14:52 -0700
Subject: [PATCH] [SPARK-10639] [SQL] Need to convert UDAF's result from scala
 to sql type

https://issues.apache.org/jira/browse/SPARK-10639

Author: Yin Huai <yhuai@databricks.com>

Closes #8788 from yhuai/udafConversion.
---
 .../sql/catalyst/CatalystTypeConverters.scala |   7 +-
 .../spark/sql/RandomDataGenerator.scala       |  16 ++-
 .../spark/sql/execution/aggregate/udaf.scala  |  37 +++++-
 .../org/apache/spark/sql/QueryTest.scala      |  21 ++--
 .../spark/sql/UserDefinedTypeSuite.scala      |  11 ++
 .../execution/AggregationQuerySuite.scala     | 108 +++++++++++++++++-
 6 files changed, 188 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 966623ed01..f25591794a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -138,8 +138,13 @@ object CatalystTypeConverters {
 
   private case class UDTConverter(
       udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+    // toCatalyst (it calls toCatalystImpl) will do null check.
     override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
-    override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue)
+
+    override def toScala(catalystValue: Any): Any = {
+      if (catalystValue == null) null else udt.deserialize(catalystValue)
+    }
+
     override def toScalaImpl(row: InternalRow, column: Int): Any =
       toScala(row.get(column, udt.sqlType))
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 4025cbcec1..e48395028e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -108,7 +108,21 @@ object RandomDataGenerator {
         arr
       })
       case BooleanType => Some(() => rand.nextBoolean())
-      case DateType => Some(() => new java.sql.Date(rand.nextInt()))
+      case DateType =>
+        val generator =
+          () => {
+            var milliseconds = rand.nextLong() % 253402329599999L
+            // -62135740800000L is the number of milliseconds before January 1, 1970, 00:00:00 GMT
+            // for "0001-01-01 00:00:00.000000". We need to find a
+            // number that is greater or equals to this number as a valid timestamp value.
+            while (milliseconds < -62135740800000L) {
+              // 253402329599999L is the the number of milliseconds since
+              // January 1, 1970, 00:00:00 GMT for "9999-12-31 23:59:59.999999".
+              milliseconds = rand.nextLong() % 253402329599999L
+            }
+            DateTimeUtils.toJavaDate((milliseconds / DateTimeUtils.MILLIS_PER_DAY).toInt)
+          }
+        Some(generator)
       case TimestampType =>
         val generator =
           () => {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index d43d3dd9ff..1114fe6552 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -40,6 +40,9 @@ sealed trait BufferSetterGetterUtils {
     var i = 0
     while (i < getters.length) {
       getters(i) = dataTypes(i) match {
+        case NullType =>
+          (row: InternalRow, ordinal: Int) => null
+
         case BooleanType =>
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.getBoolean(ordinal)
@@ -74,6 +77,14 @@ sealed trait BufferSetterGetterUtils {
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.getDecimal(ordinal, precision, scale)
 
+        case DateType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getInt(ordinal)
+
+        case TimestampType =>
+          (row: InternalRow, ordinal: Int) =>
+            if (row.isNullAt(ordinal)) null else row.getLong(ordinal)
+
         case other =>
           (row: InternalRow, ordinal: Int) =>
             if (row.isNullAt(ordinal)) null else row.get(ordinal, other)
@@ -92,6 +103,9 @@ sealed trait BufferSetterGetterUtils {
     var i = 0
     while (i < setters.length) {
       setters(i) = dataTypes(i) match {
+        case NullType =>
+          (row: MutableRow, ordinal: Int, value: Any) => row.setNullAt(ordinal)
+
         case b: BooleanType =>
           (row: MutableRow, ordinal: Int, value: Any) =>
             if (value != null) {
@@ -150,9 +164,23 @@ sealed trait BufferSetterGetterUtils {
 
         case dt: DecimalType =>
           val precision = dt.precision
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            // To make it work with UnsafeRow, we cannot use setNullAt.
+            // Please see the comment of UnsafeRow's setDecimal.
+            row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+
+        case DateType =>
           (row: MutableRow, ordinal: Int, value: Any) =>
             if (value != null) {
-              row.setDecimal(ordinal, value.asInstanceOf[Decimal], precision)
+              row.setInt(ordinal, value.asInstanceOf[Int])
+            } else {
+              row.setNullAt(ordinal)
+            }
+
+        case TimestampType =>
+          (row: MutableRow, ordinal: Int, value: Any) =>
+            if (value != null) {
+              row.setLong(ordinal, value.asInstanceOf[Long])
             } else {
               row.setNullAt(ordinal)
             }
@@ -205,6 +233,7 @@ private[sql] class MutableAggregationBufferImpl (
       throw new IllegalArgumentException(
         s"Could not access ${i}th value in this buffer because it only has $length values.")
     }
+
     toScalaConverters(i)(bufferValueGetters(i)(underlyingBuffer, offsets(i)))
   }
 
@@ -352,6 +381,10 @@ private[sql] case class ScalaUDAF(
     }
   }
 
+  private[this] lazy val outputToCatalystConverter: Any => Any = {
+    CatalystTypeConverters.createToCatalystConverter(dataType)
+  }
+
   // This buffer is only used at executor side.
   private[this] var inputAggregateBuffer: InputAggregationBuffer = null
 
@@ -424,7 +457,7 @@ private[sql] case class ScalaUDAF(
   override def eval(buffer: InternalRow): Any = {
     evalAggregateBuffer.underlyingInputBuffer = buffer
 
-    udaf.evaluate(evalAggregateBuffer)
+    outputToCatalystConverter(udaf.evaluate(evalAggregateBuffer))
   }
 
   override def toString: String = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index cada03e9ac..e3c5a42667 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -115,19 +115,26 @@ object QueryTest {
    */
   def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
     val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+
+    // We need to call prepareRow recursively to handle schemas with struct types.
+    def prepareRow(row: Row): Row = {
+      Row.fromSeq(row.toSeq.map {
+        case null => null
+        case d: java.math.BigDecimal => BigDecimal(d)
+        // Convert array to Seq for easy equality check.
+        case b: Array[_] => b.toSeq
+        case r: Row => prepareRow(r)
+        case o => o
+      })
+    }
+
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
       // Java's java.math.BigDecimal.compareTo).
       // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
       // equality test.
-      val converted: Seq[Row] = answer.map { s =>
-        Row.fromSeq(s.toSeq.map {
-          case d: java.math.BigDecimal => BigDecimal(d)
-          case b: Array[Byte] => b.toSeq
-          case o => o
-        })
-      }
+      val converted: Seq[Row] = answer.map(prepareRow)
       if (!isSorted) converted.sortBy(_.toString()) else converted
     }
     val sparkAnswer = try df.collect().toSeq catch {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 46d87843df..7992fd59ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.{BeanInfo, BeanProperty}
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -163,4 +164,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
     assert(new MyDenseVectorUDT().typeName === "mydensevector")
     assert(new OpenHashSetUDT(IntegerType).typeName === "openhashset")
   }
+
+  test("Catalyst type converter null handling for UDTs") {
+    val udt = new MyDenseVectorUDT()
+    val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt)
+    assert(toScalaConverter(null) === null)
+
+    val toCatalystConverter = CatalystTypeConverters.createToCatalystConverter(udt)
+    assert(toCatalystConverter(null) === null)
+
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a73b1bd52c..24b1846923 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,13 +17,55 @@
 
 package org.apache.spark.sql.hive.execution
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 
+class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
+
+  def inputSchema: StructType = schema
+
+  def bufferSchema: StructType = schema
+
+  def dataType: DataType = schema
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    (0 until schema.length).foreach { i =>
+      buffer.update(i, null)
+    }
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    if (!input.isNullAt(0) && input.getInt(0) == 50) {
+      (0 until schema.length).foreach { i =>
+        buffer.update(i, input.get(i))
+      }
+    }
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    if (!buffer2.isNullAt(0) && buffer2.getInt(0) == 50) {
+      (0 until schema.length).foreach { i =>
+        buffer1.update(i, buffer2.get(i))
+      }
+    }
+  }
+
+  def evaluate(buffer: Row): Any = {
+    Row.fromSeq(buffer.toSeq)
+  }
+}
+
 abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
@@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
     }
   }
+
+  test("udaf with all data types") {
+    val struct =
+      StructType(
+        StructField("f1", FloatType, true) ::
+          StructField("f2", ArrayType(BooleanType), true) :: Nil)
+    val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
+      ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+      DateType, TimestampType,
+      ArrayType(IntegerType), MapType(StringType, LongType), struct,
+      new MyDenseVectorUDT())
+    // Right now, we will use SortBasedAggregate to handle UDAFs.
+    // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
+    // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
+    // SortBasedAggregate to use a safe row as the aggregation buffer.
+    Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
+      val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+        StructField(s"col$index", dataType, nullable = true)
+      }
+      // The schema used for data generator.
+      val schemaForGenerator = StructType(fields)
+      // The schema used for the DataFrame df.
+      val schema = StructType(StructField("id", IntegerType) +: fields)
+
+      logInfo(s"Testing schema: ${schema.treeString}")
+
+      val udaf = new ScalaAggregateFunction(schema)
+      // Generate data at the driver side. We need to materialize the data first and then
+      // create RDD.
+      val maybeDataGenerator =
+        RandomDataGenerator.forType(
+          dataType = schemaForGenerator,
+          nullable = true,
+          seed = Some(System.nanoTime()))
+      val dataGenerator =
+        maybeDataGenerator
+          .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
+      val data = (1 to 50).map { i =>
+        dataGenerator.apply() match {
+          case row: Row => Row.fromSeq(i +: row.toSeq)
+          case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+          case other =>
+            fail(s"Row or null is expected to be generated, " +
+              s"but a ${other.getClass.getCanonicalName} is generated.")
+        }
+      }
+
+      // Create a DF for the schema with random data.
+      val rdd = sqlContext.sparkContext.parallelize(data, 1)
+      val df = sqlContext.createDataFrame(rdd, schema)
+
+      val allColumns = df.schema.fields.map(f => col(f.name))
+      val expectedAnaswer =
+        data
+          .find(r => r.getInt(0) == 50)
+          .getOrElse(fail("A row with id 50 should be the expected answer."))
+      checkAnswer(
+        df.groupBy().agg(udaf(allColumns: _*)),
+        // udaf returns a Row as the output value.
+        Row(expectedAnaswer)
+      )
+    }
+  }
 }
 
 class SortBasedAggregationQuerySuite extends AggregationQuerySuite {
-- 
GitLab