diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f4c42bbc5b03d3e04b3b862972b618107147c22e..cd4e5a239ec665ad675a7936d045951800e79274 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1128,7 +1128,10 @@ private[hive] object HiveQl {
         Explode(attributes, nodeToExpr(child))
 
       case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
-        HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
+        HiveGenericUdtf(
+          new HiveFunctionWrapper(functionName),
+          attributes,
+          children.map(nodeToExpr))
 
       case a: ASTNode =>
         throw new NotImplementedError(
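The parser change above is purely mechanical: the construction site now passes a `HiveFunctionWrapper` instead of a bare class name. A minimal sketch of why the bare class name was insufficient (the class and names below are hypothetical, not Spark's API): when only a `String` crosses the wire, every JVM that evaluates the expression builds a brand-new, uninitialized UDF instance, losing whatever state the driver set up during analysis.

```scala
// Hypothetical sketch of the pre-patch failure mode: only the class name is
// serialized, so the executor side always reconstructs the UDF from scratch.
case class OldStyleUdfExpr(functionClassName: String) {
  @transient lazy val function: AnyRef =
    // fresh, uninitialized instance on first access in each JVM
    Class.forName(functionClassName).getConstructor().newInstance().asInstanceOf[AnyRef]
}
```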
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index fecf8faaf4cda3958d052f8b9208c38f5ee0879e..ed2e96df8ad77cb2e0eea8947de4c51d8f23a2b7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -54,46 +54,31 @@ private[hive] abstract class HiveFunctionRegistry
     val functionClassName = functionInfo.getFunctionClass.getName
 
     if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveSimpleUdf(functionClassName, children)
+      HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdf(functionClassName, children)
+      HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
     } else if (
          classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdaf(functionClassName, children)
+      HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveUdaf(functionClassName, children)
+      HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdtf(functionClassName, Nil, children)
+      HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
     } else {
       sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
     }
   }
 }
 
-private[hive] trait HiveFunctionFactory {
-  val functionClassName: String
-
-  def createFunction[UDFType]() =
-    getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
-}
-
-private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
-  self: Product =>
-
-  type UDFType
+private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends Expression with HiveInspectors with Logging {
   type EvaluatedType = Any
+  type UDFType = UDF
 
   def nullable = true
 
-  lazy val function = createFunction[UDFType]()
-
-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
-}
-
-private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
-  extends HiveUdf with HiveInspectors {
-
-  type UDFType = UDF
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()
 
   @transient
   protected lazy val method =
@@ -131,6 +116,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
         .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
       returnInspector)
   }
+
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }
 
 // Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -144,16 +131,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
   override def get(): AnyRef = wrap(func(), oi)
 }
 
-private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
-  extends HiveUdf with HiveInspectors {
+private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+  extends Expression with HiveInspectors with Logging {
   type UDFType = GenericUDF
+  type EvaluatedType = Any
+
+  def nullable = true
+
+  @transient
+  lazy val function = funcWrapper.createFunction[UDFType]()
 
   @transient
   protected lazy val argumentInspectors = children.map(toInspector)
 
   @transient
-  protected lazy val returnInspector =
+  protected lazy val returnInspector = {
     function.initializeAndFoldConstants(argumentInspectors.toArray)
+  }
 
   @transient
   protected lazy val isUDFDeterministic = {
@@ -183,18 +177,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
     }
     unwrap(function.evaluate(deferedObjects), returnInspector)
   }
+
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }
 
 private[hive] case class HiveGenericUdaf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {
 
   type UDFType = AbstractGenericUDAFResolver
 
   @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
+  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
 
   @transient
   protected lazy val objectInspector  = {
@@ -209,22 +204,22 @@ private[hive] case class HiveGenericUdaf(
 
   def nullable: Boolean = true
 
-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 
-  def newInstance() = new HiveUdafFunction(functionClassName, children, this)
+  def newInstance() = new HiveUdafFunction(funcWrapper, children, this)
 }
 
 /** It is used as a wrapper for Hive functions that use the UDAF interface */
 private[hive] case class HiveUdaf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {
 
   type UDFType = UDAF
 
   @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
+  protected lazy val resolver: AbstractGenericUDAFResolver =
+    new GenericUDAFBridge(funcWrapper.createFunction())
 
   @transient
   protected lazy val objectInspector  = {
@@ -239,10 +234,10 @@ private[hive] case class HiveUdaf(
 
   def nullable: Boolean = true
 
-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 
   def newInstance() =
-    new HiveUdafFunction(functionClassName, children, this, true)
+    new HiveUdafFunction(funcWrapper, children, this, true)
 }
 
 /**
@@ -257,13 +252,13 @@ private[hive] case class HiveUdaf(
  * user defined aggregations, which have clean semantics even in a partitioned execution.
  */
 private[hive] case class HiveGenericUdtf(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     aliasNames: Seq[String],
     children: Seq[Expression])
-  extends Generator with HiveInspectors with HiveFunctionFactory {
+  extends Generator with HiveInspectors {
 
   @transient
-  protected lazy val function: GenericUDTF = createFunction()
+  protected lazy val function: GenericUDTF = funcWrapper.createFunction()
 
   @transient
   protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -320,25 +315,24 @@ private[hive] case class HiveGenericUdtf(
     }
   }
 
-  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
+  override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
 }
 
 private[hive] case class HiveUdafFunction(
-    functionClassName: String,
+    funcWrapper: HiveFunctionWrapper,
     exprs: Seq[Expression],
     base: AggregateExpression,
     isUDAFBridgeRequired: Boolean = false)
   extends AggregateFunction
-  with HiveInspectors
-  with HiveFunctionFactory {
+  with HiveInspectors {
 
   def this() = this(null, null, null)
 
   private val resolver =
     if (isUDAFBridgeRequired) {
-      new GenericUDAFBridge(createFunction[UDAF]())
+      new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
-      createFunction[AbstractGenericUDAFResolver]()
+      funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
 
   private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
@@ -361,3 +355,4 @@ private[hive] case class HiveUdafFunction(
     function.iterate(buffer, inputs)
   }
 }
+
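All of the rewritten expressions above hold their function as a `@transient lazy val` backed by the wrapper. A self-contained sketch of that pattern (class names hypothetical): the field is skipped by Java serialization and lazily re-created from the wrapper on first access after deserialization.

```scala
import java.io._

// Hypothetical sketch of the @transient lazy val pattern used in hiveUdfs.scala.
class NameWrapper(val className: String) extends Serializable {
  def create(): AnyRef =
    Class.forName(className).getConstructor().newInstance().asInstanceOf[AnyRef]
}

class UdfHolder(val wrapper: NameWrapper) extends Serializable {
  // never written to the stream; re-materialized lazily on the reading side
  @transient lazy val function: AnyRef = wrapper.create()
}

// Round trip: only `wrapper` travels; touching `function` afterwards re-runs
// the lazy initializer in the receiving JVM.
def roundTrip(h: UdfHolder): UdfHolder = {
  val bytes = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bytes)
  oos.writeObject(h)
  oos.close()
  new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[UdfHolder]
}
```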
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 872f28d514efebac9aa49d8e3fb5302458bbd5bc..5fcaf671a80de803d2603ace04fff9ce0e71e61d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest {
         |       getStruct(1).f5 FROM src LIMIT 1
       """.stripMargin).first() === Row(1, 2, 3, 4, 5))
   }
+
+  test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
+    checkAnswer(
+      sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
+      8
+    )
+  }
 
   test("hive struct udf") {
     sql(
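On the expected value: assuming the standard `kv1.txt` fixture, the first row of `src` has key 238, and PMOD(238, 10) = 8. Unlike `%`, PMOD always yields a non-negative remainder; a plain Scala sketch (not Hive's implementation) makes the semantics concrete:

```scala
// Positive-modulus semantics, sketched independently of Hive.
def pmod(a: Int, n: Int): Int = ((a % n) + n) % n

assert(pmod(238, 10) == 8) // the test's expected answer, assuming key 238 comes first
assert(pmod(-7, 3) == 2)   // whereas -7 % 3 == -1
```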
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index 76f09cbcdec9957973a1eab83536c97baac5802b..754ffc422072d7e8a27f9dbe12ede7c7cae0a753 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -43,6 +43,17 @@ import scala.language.implicitConversions
 
 import org.apache.spark.sql.catalyst.types.DecimalType
 
+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable {
+  // for Serialization
+  def this() = this(null)
+
+  import org.apache.spark.util.Utils._
+  def createFunction[UDFType <: AnyRef](): UDFType = {
+    getContextOrSparkClassLoader
+      .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+  }
+}
+
 /**
  * A compatibility layer for interacting with Hive version 0.12.0.
  */
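Under the 0.12 shim no instance state has to survive the trip, so the wrapper is a plain `Serializable` holder that always re-instantiates from the class name. A hedged usage sketch (the UDF class is illustrative):

```scala
// Ship only the class name; rebuild the function wherever it is needed.
val wrapper = new HiveFunctionWrapper("org.apache.hadoop.hive.ql.udf.UDFPI")
val fn = wrapper.createFunction[org.apache.hadoop.hive.ql.exec.UDF]()
```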
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 91f7ceac211778bdb4d589009d13536e2689565f..7c8cbf10c1c30259dd765df5bb0f7516683daacf 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive
 
 import java.util.{ArrayList => JArrayList}
 import java.util.Properties
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapred.InputFormat
@@ -42,6 +43,112 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 
+
+/**
+ * This class provides UDF creation as well as serialization and de-serialization
+ * of the UDF instance across process boundaries.
+ *
+ * A detailed discussion can be found at https://github.com/apache/spark/pull/3640
+ *
+ * @param functionClassName UDF class name
+ */
+class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
+  // no-arg constructor, required by Externalizable
+  def this() = this(null)
+
+  import java.io.{OutputStream, InputStream}
+  import com.esotericsoftware.kryo.Kryo
+  import org.apache.spark.util.Utils._
+  import org.apache.hadoop.hive.ql.exec.Utilities
+  import org.apache.hadoop.hive.ql.exec.UDF
+
+  @transient
+  private val methodDeSerialize = {
+    val method = classOf[Utilities].getDeclaredMethod(
+      "deserializeObjectByKryo",
+      classOf[Kryo],
+      classOf[InputStream],
+      classOf[Class[_]])
+    method.setAccessible(true)
+
+    method
+  }
+
+  @transient
+  private val methodSerialize = {
+    val method = classOf[Utilities].getDeclaredMethod(
+      "serializeObjectByKryo",
+      classOf[Kryo],
+      classOf[Object],
+      classOf[OutputStream])
+    method.setAccessible(true)
+
+    method
+  }
+
+  def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
+    methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz)
+      .asInstanceOf[UDFType]
+  }
+
+  def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
+    methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out)
+  }
+
+  private var instance: AnyRef = null
+
+  def writeExternal(out: java.io.ObjectOutput) {
+    // output the function name
+    out.writeUTF(functionClassName)
+
+    // write a flag indicating whether the instance is null
+    out.writeBoolean(instance != null)
+    if (instance != null) {
+      // Some UDFs are serializable while others are not;
+      // Hive's Utilities can handle both cases
+      val baos = new java.io.ByteArrayOutputStream()
+      serializePlan(instance, baos)
+      val functionInBytes = baos.toByteArray
+
+      // output the function bytes
+      out.writeInt(functionInBytes.length)
+      out.write(functionInBytes, 0, functionInBytes.length)
+    }
+  }
+
+  def readExternal(in: java.io.ObjectInput) {
+    // read the function name
+    functionClassName = in.readUTF()
+
+    if (in.readBoolean()) {
+      // the instance was non-null when written out;
+      // read the serialized function bytes back
+      val functionInBytesLength = in.readInt()
+      val functionInBytes = new Array[Byte](functionInBytesLength)
+      in.readFully(functionInBytes, 0, functionInBytesLength)
+
+      // deserialize the function object via Hive Utilities
+      instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
+        getContextOrSparkClassLoader.loadClass(functionClassName))
+    }
+  }
+
+  def createFunction[UDFType <: AnyRef](): UDFType = {
+    if (instance != null) {
+      instance.asInstanceOf[UDFType]
+    } else {
+      val func = getContextOrSparkClassLoader
+                   .loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
+      if (!func.isInstanceOf[UDF]) {
+        // We cache the function unless it is a simple UDF,
+        // since a new instance must be created for every simple UDF invocation
+        instance = func
+      }
+      func
+    }
+  }
+}
+
 /**
  * A compatibility layer for interacting with Hive version 0.13.1.
  */
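A minimal sketch of the round trip the `Externalizable` implementation above enables: the driver initializes the UDF once (caching it via `createFunction`), `writeExternal` Kryo-serializes the live instance, and `readExternal` restores it on the executor, so the constant-folded state that triggered the SPARK-4785 NPE survives intact.

```scala
import java.io._

// Plain Java serialization drives writeExternal/readExternal on the wrapper.
def shipWrapper(w: HiveFunctionWrapper): HiveFunctionWrapper = {
  val bytes = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bytes)
  oos.writeObject(w)  // invokes w.writeExternal(...)
  oos.close()
  new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[HiveFunctionWrapper]  // invokes readExternal(...)
}
```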