diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 1911588309affdb0b03dc5c9f16ad79101912fb6..9ca303a974cd4aa934ba44c57c8c78c80fcd177f 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -169,16 +169,20 @@ class GroupedData(object):
 
     @since(1.6)
     def pivot(self, pivot_col, values=None):
-        """Pivots a column of the current DataFrame and perform the specified aggregation.
+        """
+        Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
+        There are two versions of pivot function: one that requires the caller to specify the list
+        of distinct values to pivot on, and one that does not. The latter is more concise but less
+        efficient, because Spark needs to first compute the list of distinct values internally.
 
-        :param pivot_col: Column to pivot
-        :param values: Optional list of values of pivot column that will be translated to columns in
-            the output DataFrame. If values are not provided the method will do an immediate call
-            to .distinct() on the pivot column.
+        :param pivot_col: Name of the column to pivot.
+        :param values: List of values that will be translated to columns in the output DataFrame.
 
+        # Compute the sum of earnings for each year by course with each course as a separate column
         >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
         [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
 
+        # Or without specifying column values (less efficient)
         >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
         """