From 3e4f09d2fce9dcf45eaaca827f2cf15c9d4a6c75 Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian.cs.zju@gmail.com>
Date: Wed, 8 Oct 2014 18:13:22 -0700
Subject: [PATCH] [SQL] Prevents per row dynamic dispatching and pattern
 matching when inserting Hive values

Builds all wrappers at first according to object inspector types to avoid per row costs.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2592 from liancheng/hive-value-wrapper and squashes the following commits:

9696559 [Cheng Lian] Passes all tests
4998666 [Cheng Lian] Prevents per row dynamic dispatching and pattern matching when inserting Hive values
---
 .../hive/execution/InsertIntoHiveTable.scala  | 64 ++++++++++---------
 1 file changed, 34 insertions(+), 30 deletions(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index f8b4e898ec..f0785d8882 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -69,33 +69,36 @@ case class InsertIntoHiveTable(
    * Wraps with Hive types based on object inspector.
    * TODO: Consolidate all hive OI/data interface code.
    */
-  protected def wrap(a: (Any, ObjectInspector)): Any = a match {
-    case (s: String, oi: JavaHiveVarcharObjectInspector) =>
-      new HiveVarchar(s, s.size)
-
-    case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) =>
-      new HiveDecimal(bd.underlying())
-
-    case (row: Row, oi: StandardStructObjectInspector) =>
-      val struct = oi.create()
-      row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach {
-        case (data, field) =>
-          oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector))
+  protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
+    case _: JavaHiveVarcharObjectInspector =>
+      (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+
+    case _: JavaHiveDecimalObjectInspector =>
+      (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying())
+
+    case soi: StandardStructObjectInspector =>
+      val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
+      (o: Any) => {
+        val struct = soi.create()
+        (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+          (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+        }
+        struct
       }
-      struct
 
-    case (s: Seq[_], oi: ListObjectInspector) =>
-      val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector))
-      seqAsJavaList(wrappedSeq)
+    case loi: ListObjectInspector =>
+      val wrapper = wrapperFor(loi.getListElementObjectInspector)
+      (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
 
-    case (m: Map[_, _], oi: MapObjectInspector) =>
-      val keyOi = oi.getMapKeyObjectInspector
-      val valueOi = oi.getMapValueObjectInspector
-      val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) }
-      mapAsJavaMap(wrappedMap)
+    case moi: MapObjectInspector =>
+      val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
+      val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
+      (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
+        keyWrapper(key) -> valueWrapper(value)
+      })
 
-    case (obj, _) =>
-      obj
+    case _ =>
+      identity[Any]
   }
 
   def saveAsHiveFile(
@@ -103,7 +106,7 @@ case class InsertIntoHiveTable(
       valueClass: Class[_],
       fileSinkConf: FileSinkDesc,
       conf: SerializableWritable[JobConf],
-      writerContainer: SparkHiveWriterContainer) {
+      writerContainer: SparkHiveWriterContainer): Unit = {
     assert(valueClass != null, "Output value class not set")
     conf.value.setOutputValueClass(valueClass)
 
@@ -122,7 +125,7 @@ case class InsertIntoHiveTable(
     writerContainer.commitJob()
 
     // Note that this function is executed on executor side
-    def writeToFile(context: TaskContext, iterator: Iterator[Row]) {
+    def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = {
       val serializer = newSerializer(fileSinkConf.getTableInfo)
       val standardOI = ObjectInspectorUtils
         .getStandardObjectInspector(
@@ -131,6 +134,7 @@ case class InsertIntoHiveTable(
         .asInstanceOf[StructObjectInspector]
 
       val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
+      val wrappers = fieldOIs.map(wrapperFor)
       val outputData = new Array[Any](fieldOIs.length)
 
       // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
@@ -141,13 +145,13 @@ case class InsertIntoHiveTable(
       iterator.foreach { row =>
         var i = 0
         while (i < fieldOIs.length) {
-          // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap`
-          outputData(i) = wrap(row(i), fieldOIs(i))
+          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i))
           i += 1
         }
 
-        val writer = writerContainer.getLocalFileWriter(row)
-        writer.write(serializer.serialize(outputData, standardOI))
+        writerContainer
+          .getLocalFileWriter(row)
+          .write(serializer.serialize(outputData, standardOI))
       }
 
       writerContainer.close()
@@ -207,7 +211,7 @@ case class InsertIntoHiveTable(
 
       // Report error if any static partition appears after a dynamic partition
       val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty)
-      isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ =>
+      if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) {
         throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
       }
     }
-- 
GitLab