Commit 3e4f09d2 authored by Cheng Lian, committed by Michael Armbrust

[SQL] Prevents per row dynamic dispatching and pattern matching when inserting Hive values

Builds all wrappers at first according to object inspector types to avoid per row costs.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2592 from liancheng/hive-value-wrapper and squashes the following commits:

9696559 [Cheng Lian] Passes all tests
4998666 [Cheng Lian] Prevents per row dynamic dispatching and pattern matching when inserting Hive values
parent e7033572
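
The pattern at the heart of this change: run the pattern match on the object inspector type once, capture the result in an `Any => Any` closure, and call only that closure for each row. A minimal standalone sketch of the idea, using made-up inspector types instead of Hive's (all names below are illustrative, not Spark's):

object WrapperSketch extends App {
  // Hypothetical stand-ins for Hive object inspectors.
  sealed trait Inspector
  case object StringInspector extends Inspector
  case class ListInspector(element: Inspector) extends Inspector

  // Pattern match once per schema node; return a closure to reuse per row.
  def wrapperFor(i: Inspector): Any => Any = i match {
    case StringInspector =>
      (o: Any) => o.asInstanceOf[String].toUpperCase
    case ListInspector(el) =>
      val elementWrapper = wrapperFor(el) // recursion happens up front, not per row
      (o: Any) => o.asInstanceOf[Seq[_]].map(elementWrapper)
  }

  val wrap = wrapperFor(ListInspector(StringInspector)) // built once per schema
  val rows = Seq(Seq("a", "b"), Seq("c"))
  rows.foreach(row => println(wrap(row))) // per-row cost: one closure call
}

Before the change, `wrap` re-ran the full match (and, for structs, re-resolved field refs) for every value of every row; after it, all dispatching happens while the wrapper tree is built.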
@@ -69,33 +69,36 @@ case class InsertIntoHiveTable(
    * Wraps with Hive types based on object inspector.
    * TODO: Consolidate all hive OI/data interface code.
    */
-  protected def wrap(a: (Any, ObjectInspector)): Any = a match {
-    case (s: String, oi: JavaHiveVarcharObjectInspector) =>
-      new HiveVarchar(s, s.size)
-    case (bd: BigDecimal, oi: JavaHiveDecimalObjectInspector) =>
-      new HiveDecimal(bd.underlying())
-    case (row: Row, oi: StandardStructObjectInspector) =>
-      val struct = oi.create()
-      row.zip(oi.getAllStructFieldRefs: Seq[StructField]).foreach {
-        case (data, field) =>
-          oi.setStructFieldData(struct, field, wrap(data, field.getFieldObjectInspector))
-      }
-      struct
-    case (s: Seq[_], oi: ListObjectInspector) =>
-      val wrappedSeq = s.map(wrap(_, oi.getListElementObjectInspector))
-      seqAsJavaList(wrappedSeq)
-    case (m: Map[_, _], oi: MapObjectInspector) =>
-      val keyOi = oi.getMapKeyObjectInspector
-      val valueOi = oi.getMapValueObjectInspector
-      val wrappedMap = m.map { case (key, value) => wrap(key, keyOi) -> wrap(value, valueOi) }
-      mapAsJavaMap(wrappedMap)
-    case (obj, _) =>
-      obj
+  protected def wrapperFor(oi: ObjectInspector): Any => Any = oi match {
+    case _: JavaHiveVarcharObjectInspector =>
+      (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
+    case _: JavaHiveDecimalObjectInspector =>
+      (o: Any) => new HiveDecimal(o.asInstanceOf[BigDecimal].underlying())
+    case soi: StandardStructObjectInspector =>
+      val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
+      (o: Any) => {
+        val struct = soi.create()
+        (soi.getAllStructFieldRefs, wrappers, o.asInstanceOf[Row]).zipped.foreach {
+          (field, wrapper, data) => soi.setStructFieldData(struct, field, wrapper(data))
+        }
+        struct
+      }
+    case loi: ListObjectInspector =>
+      val wrapper = wrapperFor(loi.getListElementObjectInspector)
+      (o: Any) => seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper))
+    case moi: MapObjectInspector =>
+      val keyWrapper = wrapperFor(moi.getMapKeyObjectInspector)
+      val valueWrapper = wrapperFor(moi.getMapValueObjectInspector)
+      (o: Any) => mapAsJavaMap(o.asInstanceOf[Map[_, _]].map { case (key, value) =>
+        keyWrapper(key) -> valueWrapper(value)
+      })
+    case _ =>
+      identity[Any]
   }

   def saveAsHiveFile(
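
The struct case above uses `(as, bs, cs).zipped.foreach`, the Scala 2.x `Tuple3Zipped` idiom for walking three collections in lockstep without first materializing a collection of triples (Scala 2.13 later replaced it with `lazyZip`). A tiny self-contained illustration:

val fields   = Seq("name", "age")
val wrappers = Seq[String => String](_.toUpperCase, _.reverse)
val data     = Seq("alice", "42")

// One pass over all three sequences; no intermediate Seq of tuples is built.
(fields, wrappers, data).zipped.foreach { (field, wrapper, datum) =>
  println(s"$field -> ${wrapper(datum)}")
}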
@@ -103,7 +106,7 @@ case class InsertIntoHiveTable(
       valueClass: Class[_],
       fileSinkConf: FileSinkDesc,
       conf: SerializableWritable[JobConf],
-      writerContainer: SparkHiveWriterContainer) {
+      writerContainer: SparkHiveWriterContainer): Unit = {
     assert(valueClass != null, "Output value class not set")
     conf.value.setOutputValueClass(valueClass)
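
The `): Unit = {` changes in this hunk and the next replace Scala's procedure syntax, in which the `Unit` result type is implicit; that syntax was later deprecated and dropped in Scala 3. An illustrative comparison (both definitions are hypothetical, not from the patch):

// Procedure syntax: result type Unit is implicit; deprecated in later Scala versions.
def logOld(msg: String) { println(msg) }

// Explicit result type: the form this commit switches to.
def logNew(msg: String): Unit = { println(msg) }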
@@ -122,7 +125,7 @@ case class InsertIntoHiveTable(
     writerContainer.commitJob()

     // Note that this function is executed on executor side
-    def writeToFile(context: TaskContext, iterator: Iterator[Row]) {
+    def writeToFile(context: TaskContext, iterator: Iterator[Row]): Unit = {
       val serializer = newSerializer(fileSinkConf.getTableInfo)
       val standardOI = ObjectInspectorUtils
         .getStandardObjectInspector(
@@ -131,6 +134,7 @@ case class InsertIntoHiveTable(
         .asInstanceOf[StructObjectInspector]

       val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray
+      val wrappers = fieldOIs.map(wrapperFor)
       val outputData = new Array[Any](fieldOIs.length)

       // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
@@ -141,13 +145,13 @@ case class InsertIntoHiveTable(
       iterator.foreach { row =>
         var i = 0
         while (i < fieldOIs.length) {
-          // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap`
-          outputData(i) = wrap(row(i), fieldOIs(i))
+          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i))
           i += 1
         }

-        val writer = writerContainer.getLocalFileWriter(row)
-        writer.write(serializer.serialize(outputData, standardOI))
+        writerContainer
+          .getLocalFileWriter(row)
+          .write(serializer.serialize(outputData, standardOI))
       }

       writerContainer.close()
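
With `wrappers` computed once per task (one closure per field), the per-row loop above reduces to an array lookup, a null check, and a closure call. A sketch of the same loop shape over plain arrays (`row`, `wrappers`, and `outputData` here are stand-ins, not Spark's types):

val wrappers: Array[Any => Any] = Array(
  (o: Any) => o.asInstanceOf[String].trim,
  (o: Any) => o.asInstanceOf[Int] + 1
)
val row        = Array[Any]("  hive  ", 41)
val outputData = new Array[Any](wrappers.length)

var i = 0
while (i < wrappers.length) { // while loop: no iterator allocation on the hot path
  outputData(i) = if (row(i) == null) null else wrappers(i)(row(i))
  i += 1
}
println(outputData.mkString(", ")) // prints: hive, 42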
@@ -207,7 +211,7 @@ case class InsertIntoHiveTable(
     // Report error if any static partition appears after a dynamic partition
     val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty)
-    isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ =>
+    if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) {
       throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
     }
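
The rewritten check works pairwise over adjacent partition columns: a `(true, false)` pair means a dynamic partition is followed by a static one, which Hive forbids. For example, with hypothetical partition columns `(ds, hr)` where `ds` is dynamic and `hr` is static:

// isDynamic(i) is true when partition column i has no static value.
val isDynamic = Seq(true, false)
val adjacent  = isDynamic.init.zip(isDynamic.tail) // Seq((true, false))
assert(adjacent.contains((true, false)))           // static after dynamic: reject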