Skip to content
Snippets Groups Projects
Commit db436e36 authored by Davies Liu, committed by Josh Rosen
Browse files

[SPARK-2871] [PySpark] add `key` argument for max(), min() and top(n)

RDD.max(key=None)

        Find the maximum item in this RDD.

        param key: A function used to generate key for comparing

        >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
        >>> rdd.max()
        43.0
        >>> rdd.max(key=str)
        5.0

RDD.min(key=None)

        Find the minimum item in this RDD.

        param key: A function used to generate key for comparing

        >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
        >>> rdd.min()
        2.0
        >>> rdd.min(key=str)
        10.0

RDD.top(num, key=None)

        Get the top N elements from an RDD.

        Note: It returns the list sorted in descending order.
        >>> sc.parallelize([10, 4, 2, 12, 3]).top(1)
        [12]
        >>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
        [6, 5]
        >>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
        [4, 3, 2]

Author: Davies Liu <davies.liu@gmail.com>

Closes #2094 from davies/cmp and squashes the following commits:

ccbaf25 [Davies Liu] add `key` to top()
ad7e374 [Davies Liu] fix tests
2f63512 [Davies Liu] change `comp` to `key` in min/max
dd91e08 [Davies Liu] add `comp` argument for RDD.max() and RDD.min()
parent 3519b5e8
No related branches found
No related tags found
No related merge requests found
......@@ -810,23 +810,37 @@ class RDD(object):
return self.mapPartitions(func).fold(zeroValue, combOp)
def max(self):
def max(self, key=None):
"""
Find the maximum item in this RDD.
>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).max()
@param key: A function used to generate key for comparing
>>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0])
>>> rdd.max()
43.0
>>> rdd.max(key=str)
5.0
"""
return self.reduce(max)
if key is None:
return self.reduce(max)
return self.reduce(lambda a, b: max(a, b, key=key))
def min(self):
def min(self, key=None):
"""
Find the minimum item in this RDD.
>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
1.0
@param key: A function used to generate key for comparing
>>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0])
>>> rdd.min()
2.0
>>> rdd.min(key=str)
10.0
"""
return self.reduce(min)
if key is None:
return self.reduce(min)
return self.reduce(lambda a, b: min(a, b, key=key))
def sum(self):
"""
......@@ -924,7 +938,7 @@ class RDD(object):
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
def top(self, num):
def top(self, num, key=None):
"""
Get the top N elements from a RDD.
......@@ -933,20 +947,16 @@ class RDD(object):
[12]
>>> sc.parallelize([2, 3, 4, 5, 6], 2).top(2)
[6, 5]
>>> sc.parallelize([10, 4, 2, 12, 3]).top(3, key=str)
[4, 3, 2]
"""
def topIterator(iterator):
q = []
for k in iterator:
if len(q) < num:
heapq.heappush(q, k)
else:
heapq.heappushpop(q, k)
yield q
yield heapq.nlargest(num, iterator, key=key)
def merge(a, b):
return next(topIterator(a + b))
return heapq.nlargest(num, a + b, key=key)
return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True)
return self.mapPartitions(topIterator).reduce(merge)
def takeOrdered(self, num, key=None):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment