Skip to content
Snippets Groups Projects
Commit 4bc3bb29 authored by Jeremy Freeman's avatar Jeremy Freeman Committed by Josh Rosen
Browse files

StatCounter on NumPy arrays [PYSPARK][SPARK-2012]

These changes allow StatCounters to work properly on NumPy arrays, to fix the issue reported here  (https://issues.apache.org/jira/browse/SPARK-2012).

If NumPy is installed, the NumPy functions ``maximum``, ``minimum``, and ``sqrt``, which work on arrays, are used to merge statistics. If not, we fall back on scalar operators, so it will work on arrays with NumPy, but will also work without NumPy.

New unit tests added, along with a check for NumPy in the tests.

Author: Jeremy Freeman <the.freeman.lab@gmail.com>

Closes #1725 from freeman-lab/numpy-max-statcounter and squashes the following commits:

fe973b1 [Jeremy Freeman] Avoid duplicate array import in tests
7f0e397 [Jeremy Freeman] Refactored check for numpy
8e764dd [Jeremy Freeman] Explicit numpy imports
875414c [Jeremy Freeman] Fixed indents
1c8a832 [Jeremy Freeman] Unit tests for StatCounter with NumPy arrays
176a127 [Jeremy Freeman] Use numpy arrays in StatCounter
parent fda47598
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,13 @@
import copy
import math
try:
from numpy import maximum, minimum, sqrt
except ImportError:
maximum = max
minimum = min
sqrt = math.sqrt
class StatCounter(object):
......@@ -39,10 +46,8 @@ class StatCounter(object):
self.n += 1
self.mu += delta / self.n
self.m2 += delta * (value - self.mu)
if self.maxValue < value:
self.maxValue = value
if self.minValue > value:
self.minValue = value
self.maxValue = maximum(self.maxValue, value)
self.minValue = minimum(self.minValue, value)
return self
......@@ -70,8 +75,8 @@ class StatCounter(object):
else:
self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
self.maxValue = max(self.maxValue, other.maxValue)
self.minValue = min(self.minValue, other.minValue)
self.maxValue = maximum(self.maxValue, other.maxValue)
self.minValue = minimum(self.minValue, other.minValue)
self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
self.n += other.n
......@@ -115,14 +120,14 @@ class StatCounter(object):
# Return the standard deviation of the values.
def stdev(self):
return math.sqrt(self.variance())
return sqrt(self.variance())
#
# Return the sample standard deviation of the values, which corrects for bias in estimating the
# variance by dividing by N-1 instead of N.
#
def sampleStdev(self):
return math.sqrt(self.sampleVariance())
return sqrt(self.sampleVariance())
def __repr__(self):
return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
......
......@@ -38,12 +38,19 @@ from pyspark.serializers import read_int
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
_have_scipy = False
_have_numpy = False
try:
import scipy.sparse
_have_scipy = True
except:
# No SciPy, but that's okay, we'll skip those tests
pass
try:
import numpy as np
_have_numpy = True
except:
# No NumPy, but that's okay, we'll skip those tests
pass
SPARK_HOME = os.environ["SPARK_HOME"]
......@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
self.assertEqual(expected, observed)
@unittest.skipIf(not _have_numpy, "NumPy not installed")
class NumPyTests(PySparkTestCase):
"""General PySpark tests that depend on numpy """
def test_statcounter_array(self):
x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])])
s = x.stats()
self.assertSequenceEqual([2.0,2.0], s.mean().tolist())
self.assertSequenceEqual([1.0,1.0], s.min().tolist())
self.assertSequenceEqual([3.0,3.0], s.max().tolist())
self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist())
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
if not _have_numpy:
print "NOTE: Skipping NumPy tests as it does not seem to be installed"
unittest.main()
if not _have_scipy:
print "NOTE: SciPy tests were skipped as it does not seem to be installed"
if not _have_numpy:
print "NOTE: NumPy tests were skipped as it does not seem to be installed"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment