diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f4c42bbc5b03d3e04b3b862972b618107147c22e..cd4e5a239ec665ad675a7936d045951800e79274 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1128,7 +1128,10 @@ private[hive] object HiveQl { Explode(attributes, nodeToExpr(child)) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr)) + HiveGenericUdtf( + new HiveFunctionWrapper(functionName), + attributes, + children.map(nodeToExpr)) case a: ASTNode => throw new NotImplementedError( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index fecf8faaf4cda3958d052f8b9208c38f5ee0879e..ed2e96df8ad77cb2e0eea8947de4c51d8f23a2b7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -54,46 +54,31 @@ private[hive] abstract class HiveFunctionRegistry val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(functionClassName, children) + HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(functionClassName, children) + HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(functionClassName, children) + HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(functionClassName, children) + HiveUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(functionClassName, Nil, children) + HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } } } -private[hive] trait HiveFunctionFactory { - val functionClassName: String - - def createFunction[UDFType]() = - getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType] -} - -private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory { - self: Product => - - type UDFType +private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with Logging { type EvaluatedType = Any + type UDFType = UDF def nullable = true - lazy val function = createFunction[UDFType]() - - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" -} - -private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf with HiveInspectors { - - type UDFType = UDF + @transient + lazy val function = funcWrapper.createFunction[UDFType]() @transient protected lazy val method = @@ -131,6 +116,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), returnInspector) } + + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -144,16 +131,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf with HiveInspectors { +private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF + type EvaluatedType = Any + + def nullable = true + + @transient + lazy val function = funcWrapper.createFunction[UDFType]() @transient protected lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = + protected lazy val returnInspector = { function.initializeAndFoldConstants(argumentInspectors.toArray) + } @transient protected lazy val isUDFDeterministic = { @@ -183,18 +177,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq } unwrap(function.evaluate(deferedObjects), returnInspector) } + + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } private[hive] case class HiveGenericUdaf( - functionClassName: String, + funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression - with HiveInspectors - with HiveFunctionFactory { + with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @transient - protected lazy val resolver: AbstractGenericUDAFResolver = createFunction() + protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() @transient protected lazy val objectInspector = { @@ -209,22 +204,22 @@ private[hive] case class HiveGenericUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - def newInstance() = new HiveUdafFunction(functionClassName, children, this) + def newInstance() = new HiveUdafFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUdaf( - functionClassName: String, + funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression - with HiveInspectors - with HiveFunctionFactory { + with HiveInspectors { type UDFType = UDAF @transient - protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction()) + protected lazy val resolver: AbstractGenericUDAFResolver = + new GenericUDAFBridge(funcWrapper.createFunction()) @transient protected lazy val objectInspector = { @@ -239,10 +234,10 @@ private[hive] case class HiveUdaf( def nullable: Boolean = true - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" def newInstance() = - new HiveUdafFunction(functionClassName, children, this, true) + new HiveUdafFunction(funcWrapper, children, this, true) } /** @@ -257,13 +252,13 @@ private[hive] case class HiveUdaf( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUdtf( - functionClassName: String, + funcWrapper: HiveFunctionWrapper, aliasNames: Seq[String], children: Seq[Expression]) - extends Generator with HiveInspectors with HiveFunctionFactory { + extends Generator with HiveInspectors { @transient - protected lazy val function: GenericUDTF = createFunction() + protected lazy val function: GenericUDTF = funcWrapper.createFunction() @transient protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) @@ -320,25 +315,24 @@ private[hive] case class HiveGenericUdtf( } } - override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})" + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } private[hive] case class HiveUdafFunction( - functionClassName: String, + funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, isUDAFBridgeRequired: Boolean = false) extends AggregateFunction - with HiveInspectors - with HiveFunctionFactory { + with HiveInspectors { def this() = this(null, null, null) private val resolver = if (isUDAFBridgeRequired) { - new GenericUDAFBridge(createFunction[UDAF]()) + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) } else { - createFunction[AbstractGenericUDAFResolver]() + funcWrapper.createFunction[AbstractGenericUDAFResolver]() } private val inspectors = exprs.map(_.dataType).map(toInspector).toArray @@ -361,3 +355,4 @@ private[hive] case class HiveUdafFunction( function.iterate(buffer, inputs) } } + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 872f28d514efebac9aa49d8e3fb5302458bbd5bc..5fcaf671a80de803d2603ace04fff9ce0e71e61d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f5 FROM src LIMIT 1 """.stripMargin).first() === Row(1, 2, 3, 4, 5)) } + + test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { + checkAnswer( + sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), + 8 + ) + } test("hive struct udf") { sql( diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index 76f09cbcdec9957973a1eab83536c97baac5802b..754ffc422072d7e8a27f9dbe12ede7c7cae0a753 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -43,6 +43,17 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.types.DecimalType +class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable { + // for Serialization + def this() = this(null) + + import org.apache.spark.util.Utils._ + def createFunction[UDFType <: AnyRef](): UDFType = { + getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + } +} + /** * A compatibility layer for interacting with Hive version 0.12.0. */ diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 91f7ceac211778bdb4d589009d13536e2689565f..7c8cbf10c1c30259dd765df5bb0f7516683daacf 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive import java.util.{ArrayList => JArrayList} import java.util.Properties + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InputFormat @@ -42,6 +43,112 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions + +/** + * This class provides the UDF creation and also the UDF instance serialization and + * de-serialization cross process boundary. + * + * Detail discussion can be found at https://github.com/apache/spark/pull/3640 + * + * @param functionClassName UDF class name + */ +class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable { + // for Serialization + def this() = this(null) + + import java.io.{OutputStream, InputStream} + import com.esotericsoftware.kryo.Kryo + import org.apache.spark.util.Utils._ + import org.apache.hadoop.hive.ql.exec.Utilities + import org.apache.hadoop.hive.ql.exec.UDF + + @transient + private val methodDeSerialize = { + val method = classOf[Utilities].getDeclaredMethod( + "deserializeObjectByKryo", + classOf[Kryo], + classOf[InputStream], + classOf[Class[_]]) + method.setAccessible(true) + + method + } + + @transient + private val methodSerialize = { + val method = classOf[Utilities].getDeclaredMethod( + "serializeObjectByKryo", + classOf[Kryo], + classOf[Object], + classOf[OutputStream]) + method.setAccessible(true) + + method + } + + def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = { + methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz) + .asInstanceOf[UDFType] + } + + def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = { + methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out) + } + + private var instance: AnyRef = null + + def writeExternal(out: java.io.ObjectOutput) { + // output the function name + out.writeUTF(functionClassName) + + // Write a flag if instance is null or not + out.writeBoolean(instance != null) + if (instance != null) { + // Some of the UDF are serializable, but some others are not + // Hive Utilities can handle both cases + val baos = new java.io.ByteArrayOutputStream() + serializePlan(instance, baos) + val functionInBytes = baos.toByteArray + + // output the function bytes + out.writeInt(functionInBytes.length) + out.write(functionInBytes, 0, functionInBytes.length) + } + } + + def readExternal(in: java.io.ObjectInput) { + // read the function name + functionClassName = in.readUTF() + + if (in.readBoolean()) { + // if the instance is not null + // read the function in bytes + val functionInBytesLength = in.readInt() + val functionInBytes = new Array[Byte](functionInBytesLength) + in.read(functionInBytes, 0, functionInBytesLength) + + // deserialize the function object via Hive Utilities + instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes), + getContextOrSparkClassLoader.loadClass(functionClassName)) + } + } + + def createFunction[UDFType <: AnyRef](): UDFType = { + if (instance != null) { + instance.asInstanceOf[UDFType] + } else { + val func = getContextOrSparkClassLoader + .loadClass(functionClassName).newInstance.asInstanceOf[UDFType] + if (!func.isInstanceOf[UDF]) { + // We cache the function if it's no the Simple UDF, + // as we always have to create new instance for Simple UDF + instance = func + } + func + } + } +} + /** * A compatibility layer for interacting with Hive version 0.13.1. */