Commit 4bc3bb29 authored by Jeremy Freeman, committed by Josh Rosen

StatCounter on NumPy arrays [PYSPARK][SPARK-2012]

These changes allow StatCounters to work properly on NumPy arrays, fixing the issue reported in SPARK-2012 (https://issues.apache.org/jira/browse/SPARK-2012).

If NumPy is installed, the NumPy functions ``maximum``, ``minimum``, and ``sqrt``, which operate element-wise on arrays, are used to merge statistics. If it is not, we fall back to the scalar builtins, so StatCounter works on NumPy arrays when NumPy is available and continues to work on plain scalars when it is not.
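
For reference, a quick demonstration (not part of the patch) of why the element-wise ufuncs let one code path serve both scalars and arrays, assuming NumPy is installed:

    import numpy as np

    # np.maximum is a ufunc: it broadcasts over scalars and arrays alike.
    print(np.maximum(3.0, 5.0))              # 5.0
    print(np.maximum(np.array([1.0, 4.0]),
                     np.array([3.0, 2.0])))  # [2. 4.] per component -> [3. 4.]

    # The builtin max() picks one whole object, which is not what we want
    # per-element, and comparing whole arrays raises a ValueError.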

New unit tests are added, along with a check for NumPy so the tests are skipped when it is not installed.

Author: Jeremy Freeman <the.freeman.lab@gmail.com>

Closes #1725 from freeman-lab/numpy-max-statcounter and squashes the following commits:

fe973b1 [Jeremy Freeman] Avoid duplicate array import in tests
7f0e397 [Jeremy Freeman] Refactored check for numpy
8e764dd [Jeremy Freeman] Explicit numpy imports
875414c [Jeremy Freeman] Fixed indents
1c8a832 [Jeremy Freeman] Unit tests for StatCounter with NumPy arrays
176a127 [Jeremy Freeman] Use numpy arrays in StatCounter
parent fda47598
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -20,6 +20,13 @@
 import copy
 import math
 
+try:
+    from numpy import maximum, minimum, sqrt
+except ImportError:
+    maximum = max
+    minimum = min
+    sqrt = math.sqrt
+
 
 class StatCounter(object):
@@ -39,10 +46,8 @@ class StatCounter(object):
         self.n += 1
         self.mu += delta / self.n
         self.m2 += delta * (value - self.mu)
-        if self.maxValue < value:
-            self.maxValue = value
-        if self.minValue > value:
-            self.minValue = value
+        self.maxValue = maximum(self.maxValue, value)
+        self.minValue = minimum(self.minValue, value)
 
         return self
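
The switch away from the ``if`` comparisons is what makes merge() work on arrays: comparing two arrays in an ``if`` raises a ValueError because the comparison result is itself an array, while ``maximum``/``minimum`` reduce per component. A small illustration (assumes NumPy; the names mirror the patch):

    import numpy as np

    maxValue = np.array([2.0, 2.0])
    value = np.array([1.0, 3.0])

    # Old style: `if` needs a single True/False, but the comparison is an array.
    try:
        if maxValue < value:
            maxValue = value
    except ValueError as e:
        print(e)  # "The truth value of an array ... is ambiguous ..."

    # New style: element-wise running max, one code path for both cases.
    print(np.maximum(maxValue, value))  # [2. 3.]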
@@ -70,8 +75,8 @@ class StatCounter(object):
             else:
                 self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
 
-            self.maxValue = max(self.maxValue, other.maxValue)
-            self.minValue = min(self.minValue, other.minValue)
+            self.maxValue = maximum(self.maxValue, other.maxValue)
+            self.minValue = minimum(self.minValue, other.minValue)
             self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
             self.n += other.n
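
The variance bookkeeping here needs no change: the ``m2`` update is the standard pairwise-merge identity for sums of squared deviations (Chan et al.), and every term in it is already element-wise for arrays. A quick sanity check of that identity, written as an illustration rather than taken from the patch:

    import numpy as np

    a = np.array([[1.0, 1.0], [2.0, 2.0]])  # values seen by partition A
    b = np.array([[3.0, 3.0], [4.0, 4.0]])  # values seen by partition B

    def running_stats(vals):
        n = len(vals)
        mu = vals.mean(axis=0)
        m2 = ((vals - mu) ** 2).sum(axis=0)  # sum of squared deviations
        return n, mu, m2

    n_a, mu_a, m2_a = running_stats(a)
    n_b, mu_b, m2_b = running_stats(b)

    # The merge rule used in mergeStats:
    delta = mu_b - mu_a
    m2 = m2_a + m2_b + (delta * delta * n_a * n_b) / (n_a + n_b)

    # Agrees with computing m2 over all four values at once:
    print(m2)                                   # [5. 5.]
    print(running_stats(np.vstack([a, b]))[2])  # [5. 5.]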
@@ -115,14 +120,14 @@ class StatCounter(object):
     # Return the standard deviation of the values.
     def stdev(self):
-        return math.sqrt(self.variance())
+        return sqrt(self.variance())
 
     #
     # Return the sample standard deviation of the values, which corrects for bias in estimating the
     # variance by dividing by N-1 instead of N.
     #
     def sampleStdev(self):
-        return math.sqrt(self.sampleVariance())
+        return sqrt(self.sampleVariance())
 
     def __repr__(self):
         return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
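
Binding ``sqrt`` at module level (NumPy's when available, ``math.sqrt`` otherwise) means ``stdev()`` and ``sampleStdev()`` return per-component standard deviations for array-valued counters; ``math.sqrt`` alone could not, as this small check shows (assumes NumPy):

    import numpy as np

    variance = np.array([1.0, 4.0])
    print(np.sqrt(variance))  # [1. 2.] -- element-wise
    # math.sqrt(variance) raises TypeError: only length-1 arrays can be
    # converted to Python scalars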
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -38,12 +38,19 @@
 from pyspark.serializers import read_int
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
 
 _have_scipy = False
+_have_numpy = False
 try:
     import scipy.sparse
     _have_scipy = True
 except:
     # No SciPy, but that's okay, we'll skip those tests
     pass
+try:
+    import numpy as np
+    _have_numpy = True
+except:
+    # No NumPy, but that's okay, we'll skip those tests
+    pass
 
 SPARK_HOME = os.environ["SPARK_HOME"]
@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
         self.assertEqual(expected, observed)
 
 
+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+    """General PySpark tests that depend on numpy"""
+
+    def test_statcounter_array(self):
+        x = self.sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0, 2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0, 3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0, 1.0], s.sampleStdev().tolist())
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: Skipping NumPy tests as it does not seem to be installed"
     unittest.main()
     if not _have_scipy:
         print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: NumPy tests were skipped as it does not seem to be installed"
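
The new test exercises the change end to end through ``RDD.stats()``. For a quicker local check without a SparkContext, a StatCounter can also be fed values directly; a hypothetical snippet, assuming NumPy and this patch are installed:

    import numpy as np
    from pyspark.statcounter import StatCounter

    s = StatCounter([np.array([1.0, 1.0]),
                     np.array([2.0, 2.0]),
                     np.array([3.0, 3.0])])

    print(s.mean())          # [2. 2.]
    print(s.min())           # [1. 1.]
    print(s.max())           # [3. 3.]
    print(s.sampleStdev())   # [1. 1.]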