Commit 1871574a authored by Michael Armbrust

[SPARK-2569][SQL] Fix shipping of TEMPORARY hive UDFs.

Instead of shipping just the function name and then looking up the info on the workers, we now ship the whole class name. Also, since the file was getting pretty large, I refactored it to move the type conversion code out into its own file.

Author: Michael Armbrust <michael@databricks.com>

Closes #1552 from marmbrus/fixTempUdfs and squashes the following commits:

b695904 [Michael Armbrust] Make add jar execute with Hive. Ship the whole function class name since sometimes we cannot look up temporary functions on the workers.
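
To make the fix concrete: previously each Hive UDF expression carried only the function name and consulted Hive's FunctionRegistry again on the executor, which cannot see TEMPORARY functions registered only in the driver's session. Now the expression carries the fully qualified class name and instantiates the UDF reflectively wherever it is evaluated. A minimal sketch of the idea follows; the ShippedUdf name is hypothetical, and the real implementations are the HiveSimpleUdf, HiveGenericUdf, HiveGenericUdaf, and HiveGenericUdtf expressions in the diff below.

// Sketch only: ship the UDF's class name instead of its registry name.
// The real code uses Utils.getContextOrSparkClassLoader rather than the raw
// thread context class loader.
case class ShippedUdf(functionClassName: String) {
  // Resolved lazily on whichever JVM evaluates the expression (driver or executor),
  // so no lookup in the session-local Hive FunctionRegistry is required.
  @transient lazy val function: AnyRef =
    Thread.currentThread().getContextClassLoader
      .loadClass(functionClassName)
      .newInstance()
      .asInstanceOf[AnyRef]
}
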
parent e060d3ee
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.hive

import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._

/* Implicit conversions */
import scala.collection.JavaConversions._

private[hive] trait HiveInspectors {

  def javaClassToDataType(clz: Class[_]): DataType = clz match {
    // writable
    case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
    case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
    case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
    case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
    case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
    case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
    case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
    case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
    case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
    case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
    case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
    case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType

    // java class
    case c: Class[_] if c == classOf[java.lang.String] => StringType
    case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
    case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
    case c: Class[_] if c == classOf[java.lang.Short] => ShortType
    case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
    case c: Class[_] if c == classOf[java.lang.Long] => LongType
    case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
    case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
    case c: Class[_] if c == classOf[java.lang.Float] => FloatType
    case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType

    // primitive type
    case c: Class[_] if c == java.lang.Short.TYPE => ShortType
    case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
    case c: Class[_] if c == java.lang.Long.TYPE => LongType
    case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
    case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
    case c: Class[_] if c == java.lang.Float.TYPE => FloatType
    case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType

    case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
  }

  /** Converts hive types to native catalyst types. */
  def unwrap(a: Any): Any = a match {
    case null => null
    case i: hadoopIo.IntWritable => i.get
    case t: hadoopIo.Text => t.toString
    case l: hadoopIo.LongWritable => l.get
    case d: hadoopIo.DoubleWritable => d.get
    case d: hiveIo.DoubleWritable => d.get
    case s: hiveIo.ShortWritable => s.get
    case b: hadoopIo.BooleanWritable => b.get
    case b: hiveIo.ByteWritable => b.get
    case b: hadoopIo.FloatWritable => b.get
    case b: hadoopIo.BytesWritable => {
      val bytes = new Array[Byte](b.getLength)
      System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
      bytes
    }
    case t: hiveIo.TimestampWritable => t.getTimestamp
    case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
    case list: java.util.List[_] => list.map(unwrap)
    case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
    case array: Array[_] => array.map(unwrap).toSeq
    case p: java.lang.Short => p
    case p: java.lang.Long => p
    case p: java.lang.Float => p
    case p: java.lang.Integer => p
    case p: java.lang.Double => p
    case p: java.lang.Byte => p
    case p: java.lang.Boolean => p
    case str: String => str
    case p: java.math.BigDecimal => p
    case p: Array[Byte] => p
    case p: java.sql.Timestamp => p
  }

  def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
    case hvoi: HiveVarcharObjectInspector =>
      if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
    case hdoi: HiveDecimalObjectInspector =>
      if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
    case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
    case li: ListObjectInspector =>
      Option(li.getList(data))
        .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
        .orNull
    case mi: MapObjectInspector =>
      Option(mi.getMap(data)).map(
        _.map {
          case (k,v) =>
            (unwrapData(k, mi.getMapKeyObjectInspector),
              unwrapData(v, mi.getMapValueObjectInspector))
        }.toMap).orNull
    case si: StructObjectInspector =>
      val allRefs = si.getAllStructFieldRefs
      new GenericRow(
        allRefs.map(r =>
          unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
  }

  /** Converts native catalyst types to the types expected by Hive */
  def wrap(a: Any): AnyRef = a match {
    case s: String => new hadoopIo.Text(s) // TODO why should be Text?
    case i: Int => i: java.lang.Integer
    case b: Boolean => b: java.lang.Boolean
    case f: Float => f: java.lang.Float
    case d: Double => d: java.lang.Double
    case l: Long => l: java.lang.Long
    case l: Short => l: java.lang.Short
    case l: Byte => l: java.lang.Byte
    case b: BigDecimal => b.bigDecimal
    case b: Array[Byte] => b
    case t: java.sql.Timestamp => t
    case s: Seq[_] => seqAsJavaList(s.map(wrap))
    case m: Map[_,_] =>
      mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
    case null => null
  }

  def toInspector(dataType: DataType): ObjectInspector = dataType match {
    case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
    case MapType(keyType, valueType) =>
      ObjectInspectorFactory.getStandardMapObjectInspector(
        toInspector(keyType), toInspector(valueType))
    case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
    case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
    case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
    case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
    case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
    case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
    case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
    case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
    case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
    case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
    case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
    case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
    case StructType(fields) =>
      ObjectInspectorFactory.getStandardStructObjectInspector(
        fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
  }

  def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
    case s: StructObjectInspector =>
      StructType(s.getAllStructFieldRefs.map(f => {
        types.StructField(
          f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
      }))
    case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
    case m: MapObjectInspector =>
      MapType(
        inspectorToDataType(m.getMapKeyObjectInspector),
        inspectorToDataType(m.getMapValueObjectInspector))
    case _: WritableStringObjectInspector => StringType
    case _: JavaStringObjectInspector => StringType
    case _: WritableIntObjectInspector => IntegerType
    case _: JavaIntObjectInspector => IntegerType
    case _: WritableDoubleObjectInspector => DoubleType
    case _: JavaDoubleObjectInspector => DoubleType
    case _: WritableBooleanObjectInspector => BooleanType
    case _: JavaBooleanObjectInspector => BooleanType
    case _: WritableLongObjectInspector => LongType
    case _: JavaLongObjectInspector => LongType
    case _: WritableShortObjectInspector => ShortType
    case _: JavaShortObjectInspector => ShortType
    case _: WritableByteObjectInspector => ByteType
    case _: JavaByteObjectInspector => ByteType
    case _: WritableFloatObjectInspector => FloatType
    case _: JavaFloatObjectInspector => FloatType
    case _: WritableBinaryObjectInspector => BinaryType
    case _: JavaBinaryObjectInspector => BinaryType
    case _: WritableHiveDecimalObjectInspector => DecimalType
    case _: JavaHiveDecimalObjectInspector => DecimalType
    case _: WritableTimestampObjectInspector => TimestampType
    case _: JavaTimestampObjectInspector => TimestampType
  }

  implicit class typeInfoConversions(dt: DataType) {
    import org.apache.hadoop.hive.serde2.typeinfo._
    import TypeInfoFactory._

    def toTypeInfo: TypeInfo = dt match {
      case BinaryType => binaryTypeInfo
      case BooleanType => booleanTypeInfo
      case ByteType => byteTypeInfo
      case DoubleType => doubleTypeInfo
      case FloatType => floatTypeInfo
      case IntegerType => intTypeInfo
      case LongType => longTypeInfo
      case ShortType => shortTypeInfo
      case StringType => stringTypeInfo
      case DecimalType => decimalTypeInfo
      case TimestampType => timestampTypeInfo
      case NullType => voidTypeInfo
    }
  }
}
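
With the conversion helpers factored out into a standalone trait, any Hive-facing operator or test can mix in HiveInspectors directly. A small illustrative sketch follows; the demo object is hypothetical and assumes it compiles in the org.apache.spark.sql.hive package with Hive's serde classes on the classpath.

// Illustrative only: exercises the conversion directions provided by the trait.
object HiveInspectorsDemo extends HiveInspectors {
  def main(args: Array[String]): Unit = {
    // Catalyst value -> Hive-facing value
    val wrapped = wrap("hello")                                      // hadoopIo.Text("hello")
    // Hive Writable -> Catalyst value
    val unwrapped = unwrap(new org.apache.hadoop.io.IntWritable(42)) // 42
    // Java class used by a UDF's eval method -> Catalyst DataType
    val dataType = javaClassToDataType(classOf[java.lang.Integer])   // IntegerType
    println(s"$wrapped / $unwrapped / $dataType")
  }
}
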
@@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command
private[hive] case class SourceCommand(filePath: String) extends Command
private[hive] case class AddJar(jarPath: String) extends Command
private[hive] case class AddFile(filePath: String) extends Command
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
@@ -229,7 +227,7 @@ private[hive] object HiveQl {
} else if (sql.trim.toLowerCase.startsWith("uncache table")) {
CacheCommand(sql.trim.drop(14).trim, false)
} else if (sql.trim.toLowerCase.startsWith("add jar")) {
AddJar(sql.trim.drop(8))
NativeCommand(sql)
} else if (sql.trim.toLowerCase.startsWith("add file")) {
AddFile(sql.trim.drop(9))
} else if (sql.trim.toLowerCase.startsWith("dfs")) {
......
@@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.Logging
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.util.Utils.getContextOrSparkClassLoader
/* Implicit conversions */
import scala.collection.JavaConversions._
private[hive] object HiveFunctionRegistry
extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors {
def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
def lookupFunction(name: String, children: Seq[Expression]): Expression = {
// We only look it up to see if it exists, but do not include it in the HiveUDF since it is
@@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry
val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
sys.error(s"Couldn't find function $name"))
val functionClassName = functionInfo.getFunctionClass.getName()
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
val function = createFunction[UDF](name)
val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
HiveSimpleUdf(
name,
functionClassName,
children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(name, children)
HiveGenericUdf(functionClassName, children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(name, children)
HiveGenericUdaf(functionClassName, children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(name, Nil, children)
HiveGenericUdtf(functionClassName, Nil, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
def javaClassToDataType(clz: Class[_]): DataType = clz match {
// writable
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
// java class
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
case c: Class[_] if c == classOf[java.lang.Long] => LongType
case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
case c: Class[_] if c == classOf[java.lang.Float] => FloatType
case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
// primitive type
case c: Class[_] if c == java.lang.Short.TYPE => ShortType
case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
case c: Class[_] if c == java.lang.Long.TYPE => LongType
case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
case c: Class[_] if c == java.lang.Float.TYPE => FloatType
case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
}
}
private[hive] trait HiveFunctionFactory {
def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
def createFunction[UDFType](name: String) =
getFunctionClass(name).newInstance.asInstanceOf[UDFType]
/** Converts hive types to native catalyst types. */
def unwrap(a: Any): Any = a match {
case null => null
case i: hadoopIo.IntWritable => i.get
case t: hadoopIo.Text => t.toString
case l: hadoopIo.LongWritable => l.get
case d: hadoopIo.DoubleWritable => d.get
case d: hiveIo.DoubleWritable => d.get
case s: hiveIo.ShortWritable => s.get
case b: hadoopIo.BooleanWritable => b.get
case b: hiveIo.ByteWritable => b.get
case b: hadoopIo.FloatWritable => b.get
case b: hadoopIo.BytesWritable => {
val bytes = new Array[Byte](b.getLength)
System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
bytes
}
case t: hiveIo.TimestampWritable => t.getTimestamp
case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
case list: java.util.List[_] => list.map(unwrap)
case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
case array: Array[_] => array.map(unwrap).toSeq
case p: java.lang.Short => p
case p: java.lang.Long => p
case p: java.lang.Float => p
case p: java.lang.Integer => p
case p: java.lang.Double => p
case p: java.lang.Byte => p
case p: java.lang.Boolean => p
case str: String => str
case p: java.math.BigDecimal => p
case p: Array[Byte] => p
case p: java.sql.Timestamp => p
}
val functionClassName: String
def createFunction[UDFType]() =
getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
}
private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
@@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
type UDFType
type EvaluatedType = Any
val name: String
def nullable = true
def references = children.flatMap(_.references).toSet
// FunctionInfo is not serializable so we must look it up here again.
lazy val functionInfo = getFunctionInfo(name)
lazy val function = createFunction[UDFType](name)
lazy val function = createFunction[UDFType]()
override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})"
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf {
import org.apache.spark.sql.hive.HiveFunctionRegistry._
type UDFType = UDF
@@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression])
}
}
private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf with HiveInspectors {
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
@@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
}
}
private[hive] trait HiveInspectors {
def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
case hvoi: HiveVarcharObjectInspector =>
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
case hdoi: HiveDecimalObjectInspector =>
if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
case li: ListObjectInspector =>
Option(li.getList(data))
.map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
.orNull
case mi: MapObjectInspector =>
Option(mi.getMap(data)).map(
_.map {
case (k,v) =>
(unwrapData(k, mi.getMapKeyObjectInspector),
unwrapData(v, mi.getMapValueObjectInspector))
}.toMap).orNull
case si: StructObjectInspector =>
val allRefs = si.getAllStructFieldRefs
new GenericRow(
allRefs.map(r =>
unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
}
/** Converts native catalyst types to the types expected by Hive */
def wrap(a: Any): AnyRef = a match {
case s: String => new hadoopIo.Text(s) // TODO why should be Text?
case i: Int => i: java.lang.Integer
case b: Boolean => b: java.lang.Boolean
case f: Float => f: java.lang.Float
case d: Double => d: java.lang.Double
case l: Long => l: java.lang.Long
case l: Short => l: java.lang.Short
case l: Byte => l: java.lang.Byte
case b: BigDecimal => b.bigDecimal
case b: Array[Byte] => b
case t: java.sql.Timestamp => t
case s: Seq[_] => seqAsJavaList(s.map(wrap))
case m: Map[_,_] =>
mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
case null => null
}
def toInspector(dataType: DataType): ObjectInspector = dataType match {
case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
case MapType(keyType, valueType) =>
ObjectInspectorFactory.getStandardMapObjectInspector(
toInspector(keyType), toInspector(valueType))
case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
}
def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
case s: StructObjectInspector =>
StructType(s.getAllStructFieldRefs.map(f => {
types.StructField(
f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
}))
case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
case m: MapObjectInspector =>
MapType(
inspectorToDataType(m.getMapKeyObjectInspector),
inspectorToDataType(m.getMapValueObjectInspector))
case _: WritableStringObjectInspector => StringType
case _: JavaStringObjectInspector => StringType
case _: WritableIntObjectInspector => IntegerType
case _: JavaIntObjectInspector => IntegerType
case _: WritableDoubleObjectInspector => DoubleType
case _: JavaDoubleObjectInspector => DoubleType
case _: WritableBooleanObjectInspector => BooleanType
case _: JavaBooleanObjectInspector => BooleanType
case _: WritableLongObjectInspector => LongType
case _: JavaLongObjectInspector => LongType
case _: WritableShortObjectInspector => ShortType
case _: JavaShortObjectInspector => ShortType
case _: WritableByteObjectInspector => ByteType
case _: JavaByteObjectInspector => ByteType
case _: WritableFloatObjectInspector => FloatType
case _: JavaFloatObjectInspector => FloatType
case _: WritableBinaryObjectInspector => BinaryType
case _: JavaBinaryObjectInspector => BinaryType
case _: WritableHiveDecimalObjectInspector => DecimalType
case _: JavaHiveDecimalObjectInspector => DecimalType
case _: WritableTimestampObjectInspector => TimestampType
case _: JavaTimestampObjectInspector => TimestampType
}
implicit class typeInfoConversions(dt: DataType) {
import org.apache.hadoop.hive.serde2.typeinfo._
import TypeInfoFactory._
def toTypeInfo: TypeInfo = dt match {
case BinaryType => binaryTypeInfo
case BooleanType => booleanTypeInfo
case ByteType => byteTypeInfo
case DoubleType => doubleTypeInfo
case FloatType => floatTypeInfo
case IntegerType => intTypeInfo
case LongType => longTypeInfo
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
case DecimalType => decimalTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
}
}
}
private[hive] case class HiveGenericUdaf(
name: String,
functionClassName: String,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {
@@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf(
type UDFType = AbstractGenericUDAFResolver
@transient
protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
@transient
protected lazy val objectInspector = {
@@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf(
def references: Set[Attribute] = children.map(_.references).flatten.toSet
override def toString = s"$nodeName#$name(${children.mkString(",")})"
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
def newInstance() = new HiveUdafFunction(name, children, this)
def newInstance() = new HiveUdafFunction(functionClassName, children, this)
}
/**
@@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf(
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
name: String,
functionClassName: String,
aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
@@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf(
override def references = children.flatMap(_.references).toSet
@transient
protected lazy val function: GenericUDTF = createFunction(name)
protected lazy val function: GenericUDTF = createFunction()
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf(
}
}
override def toString = s"$nodeName#$name(${children.mkString(",")})"
override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}
private[hive] case class HiveUdafFunction(
functionName: String,
functionClassName: String,
exprs: Seq[Expression],
base: AggregateExpression)
extends AggregateFunction
@@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction(
def this() = this(null, null, null)
private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
private val resolver = createFunction[AbstractGenericUDAFResolver]()
private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
......
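
Taken together, the HiveQl and UDF changes mean a temporary function registered on the driver can now be evaluated on executors: ADD JAR runs as a native Hive command, and each generated UDF expression ships its class name rather than a name that only the driver's session can resolve. A hedged end-to-end sketch against the Spark 1.x HiveContext API follows; the jar path, UDF class, and people table are hypothetical, and the jar must also be available on the executor classpath (for example via --jars).

// Illustrative spark-shell style session; names and paths are made up.
import org.apache.spark.sql.hive.HiveContext

val hiveContext = new HiveContext(sc)  // sc: the shell's SparkContext

// Executed as a native Hive command after this change.
hiveContext.hql("ADD JAR /tmp/my-udfs.jar")
hiveContext.hql("CREATE TEMPORARY FUNCTION my_upper AS 'com.example.MyUpperUDF'")

// The planned HiveSimpleUdf carries 'com.example.MyUpperUDF' itself, so executors
// instantiate the UDF directly instead of consulting the session's FunctionRegistry.
hiveContext.hql("SELECT my_upper(name) FROM people").collect().foreach(println)
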