diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 22b1ffc900751139ae747b7c5a0609dbab0974c8..c3cc7426b177bd9a82d75c01fdad8e01ab99849f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -324,6 +324,12 @@ class SQLTests(ReusedPySparkTestCase):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2,))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()
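
For context on what this regression test guards: before shipping rows to the Python worker, the JVM side deduplicates the UDF's argument expressions, so a call like `add(1, 1)` arrives as a one-column row together with the offset list `[0, 0]`. A minimal sketch of the failure mode, assuming that deduplicated input (the names `udf`, `a`, and `arg_offsets` mirror the worker code; the concrete row value is illustrative):

```python
# Illustrative sketch, not Spark code: the old single-UDF fast path vs. the fix.
udf = lambda x, y: x + y   # stand-in for the deserialized Python UDF
arg_offsets = [0, 0]       # both parameters point at the same deduplicated column
a = (1,)                   # one-column row as received from the JVM

# The old fast path ignored arg_offsets and unpacked the row directly:
#   udf(*a) == udf(1)  ->  fails with a wrong-argument-count TypeError
# The offset-based mapper resolves each parameter through arg_offsets instead:
result = udf(*[a[o] for o in arg_offsets])  # udf(1, 1)
print(result)  # 2
```
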
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 09182829538fc0b74fbbad0f6768aa4a85d90d2e..5eeaac70b80ed28e98d5a4ebc8c0fdc6a6a4c102 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -86,22 +86,19 @@ def read_single_udf(pickleSer, infile):
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)
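
The rewritten `read_udfs` builds its row mapper by generating source text and `eval`-ing it, which bakes the argument offsets into the lambda itself instead of looping over `arg_offsets` for every row. A standalone sketch of the same construction, with a made-up `udf_specs` standing in for what `read_single_udf` would deserialize:

```python
# Sketch of the mapper construction from read_udfs; udf_specs is hypothetical
# input (a single UDF whose two arguments share one deduplicated column).
udf_specs = [([0, 0], lambda x, y: x + y)]

udfs = {}
call_udf = []
for i, (arg_offsets, udf) in enumerate(udf_specs):
    udfs['f%d' % i] = udf
    args = ["a[%d]" % o for o in arg_offsets]
    call_udf.append("f%d(%s)" % (i, ", ".join(args)))

mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
print(mapper_str)      # lambda a: (f0(a[0], a[0]))
mapper = eval(mapper_str, udfs)

# With a single UDF the parentheses are plain grouping, not a tuple, so the
# mapper returns a bare value -- the shape the JVM side expects.
print(mapper((1,)))    # 2
```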