diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8a478fddf0e955e54a3f177fa4b583549b344da0..146ba6f3e0d983ec7461207a85e2166833f3485f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -123,7 +123,8 @@ class UserDefinedFunction(object):
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
                                                  includes, sc.pythonExec, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 258464b7f230dcf60189bddf8bc68cec57b17245..b3a6a2c6a92299f6e82a23cf9fa5fc9f7d13e961 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -25,6 +25,7 @@ import pydoc
 import shutil
 import tempfile
 import pickle
+import functools
 
 import py4j
 
@@ -41,6 +42,7 @@ from pyspark.sql import SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
+from pyspark.sql.functions import UserDefinedFunction
 
 
 class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_udf_with_callable(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        class PlusFour:
+            def __call__(self, col):
+                if col is not None:
+                    return col + 4
+
+        call = PlusFour()
+        pudf = UserDefinedFunction(call, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
+    def test_udf_with_partial_function(self):
+        d = [Row(number=i, squared=i**2) for i in range(10)]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        def some_func(col, param):
+            if col is not None:
+                return col + param
+
+        pfunc = functools.partial(some_func, param=4)
+        pudf = UserDefinedFunction(pfunc, LongType())
+        res = data.select(pudf(data['number']).alias('plus_four'))
+        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
+
     def test_udf(self):
         self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
         [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()