diff --git a/python/pyspark/statcounter.py b/python/pyspark/statcounter.py
index e287bd3da1f6178df81b40861803ec4070ff6ad3..1e597d64e03fe6f436a4a6c87a7eb58e5b0bfc1c 100644
--- a/python/pyspark/statcounter.py
+++ b/python/pyspark/statcounter.py
@@ -20,6 +20,13 @@
 import copy
 import math
 
+try:
+    from numpy import maximum, minimum, sqrt
+except ImportError:
+    maximum = max
+    minimum = min
+    sqrt = math.sqrt
+
 
 class StatCounter(object):
 
@@ -39,10 +46,8 @@ class StatCounter(object):
         self.n += 1
         self.mu += delta / self.n
         self.m2 += delta * (value - self.mu)
-        if self.maxValue < value:
-            self.maxValue = value
-        if self.minValue > value:
-            self.minValue = value
+        self.maxValue = maximum(self.maxValue, value)
+        self.minValue = minimum(self.minValue, value)
 
         return self
 
@@ -70,8 +75,8 @@ class StatCounter(object):
                 else:
                     self.mu = (self.mu * self.n + other.mu * other.n) / (self.n + other.n)
 
-                self.maxValue = max(self.maxValue, other.maxValue)
-                self.minValue = min(self.minValue, other.minValue)
+                self.maxValue = maximum(self.maxValue, other.maxValue)
+                self.minValue = minimum(self.minValue, other.minValue)
 
                 self.m2 += other.m2 + (delta * delta * self.n * other.n) / (self.n + other.n)
                 self.n += other.n
@@ -115,14 +120,14 @@ class StatCounter(object):
 
     # Return the standard deviation of the values.
     def stdev(self):
-        return math.sqrt(self.variance())
+        return sqrt(self.variance())
 
     #
     # Return the sample standard deviation of the values, which corrects for bias in estimating the
     # variance by dividing by N-1 instead of N.
     #
     def sampleStdev(self):
-        return math.sqrt(self.sampleVariance())
+        return sqrt(self.sampleVariance())
 
     def __repr__(self):
         return ("(count: %s, mean: %s, stdev: %s, max: %s, min: %s)" %
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index c29deb9574ea28650af70902a74d798d3921581a..16fb5a925622034e3177f83514ac824714e9094b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -38,12 +38,19 @@ from pyspark.serializers import read_int
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
 
 _have_scipy = False
+_have_numpy = False
 try:
     import scipy.sparse
     _have_scipy = True
 except:
     # No SciPy, but that's okay, we'll skip those tests
     pass
+try:
+    import numpy as np
+    _have_numpy = True
+except:
+    # No NumPy, but that's okay, we'll skip those tests
+    pass
 
 SPARK_HOME = os.environ["SPARK_HOME"]
 
@@ -914,9 +921,26 @@ class SciPyTests(PySparkTestCase):
         self.assertEqual(expected, observed)
 
 
+@unittest.skipIf(not _have_numpy, "NumPy not installed")
+class NumPyTests(PySparkTestCase):
+    """General PySpark tests that depend on numpy """
+
+    def test_statcounter_array(self):
+        x = self.sc.parallelize([np.array([1.0,1.0]), np.array([2.0,2.0]), np.array([3.0,3.0])])
+        s = x.stats()
+        self.assertSequenceEqual([2.0,2.0], s.mean().tolist())
+        self.assertSequenceEqual([1.0,1.0], s.min().tolist())
+        self.assertSequenceEqual([3.0,3.0], s.max().tolist())
+        self.assertSequenceEqual([1.0,1.0], s.sampleStdev().tolist())
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print "NOTE: Skipping SciPy tests as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: Skipping NumPy tests as it does not seem to be installed"
     unittest.main()
     if not _have_scipy:
         print "NOTE: SciPy tests were skipped as it does not seem to be installed"
+    if not _have_numpy:
+        print "NOTE: NumPy tests were skipped as it does not seem to be installed"
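
For context, a minimal usage sketch of the behavior this patch enables (not part of the diff itself): with NumPy installed and an existing SparkContext, stats() on an RDD of NumPy arrays now aggregates element-wise rather than failing on the scalar comparisons. The variable names `sc` and `rdd` below are assumptions; the values mirror the test_statcounter_array test added above.

    # Usage sketch: element-wise statistics over an RDD of NumPy arrays.
    # Assumes a running SparkContext named `sc`.
    import numpy as np

    rdd = sc.parallelize([np.array([1.0, 1.0]), np.array([2.0, 2.0]), np.array([3.0, 3.0])])
    s = rdd.stats()

    s.mean()         # array([ 2.,  2.])
    s.min()          # array([ 1.,  1.])
    s.max()          # array([ 3.,  3.])
    s.sampleStdev()  # array([ 1.,  1.])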