Skip to content
Snippets Groups Projects
Commit 85708168 authored by Davies Liu's avatar Davies Liu Committed by Xiangrui Meng
Browse files

[SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector

Convert the input rdd to RDD of Vector.

cc mengxr

Author: Davies Liu <davies@databricks.com>

Closes #2870 from davies/fix4023 and squashes the following commits:

1eac767 [Davies Liu] address comments
0871576 [Davies Liu] convert rdd into RDD of Vector
parent 5a8f64f3
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ Python package for statistical functions in MLlib.
from functools import wraps
from pyspark import PickleSerializer
from pyspark.mllib.linalg import _to_java_object_rdd
from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
__all__ = ['MultivariateStatisticalSummary', 'Statistics']
......@@ -107,7 +107,7 @@ class Statistics(object):
array([ 2., 0., 0., -2.])
"""
sc = rdd.ctx
jrdd = _to_java_object_rdd(rdd)
jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
return MultivariateStatisticalSummary(sc, cStats)
......@@ -163,14 +163,15 @@ class Statistics(object):
if type(y) == str:
raise TypeError("Use 'method=' to specify method name.")
jx = _to_java_object_rdd(x)
if not y:
jx = _to_java_object_rdd(x.map(_convert_to_vector))
resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
bytes = sc._jvm.SerDe.dumps(resultMat)
ser = PickleSerializer()
return ser.loads(str(bytes)).toArray()
else:
jy = _to_java_object_rdd(y)
jx = _to_java_object_rdd(x.map(float))
jy = _to_java_object_rdd(y.map(float))
return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
......
......@@ -36,6 +36,8 @@ else:
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
......@@ -202,6 +204,23 @@ class ListTests(PySparkTestCase):
self.assertTrue(dt_model.predict(features[3]) > 0)
class StatTests(PySparkTestCase):
    # SPARK-4023: Statistics.colStats must accept RDDs whose rows are not
    # already MLlib Vectors (numpy arrays, plain sequences, array.array).
    def test_col_with_different_rdds(self):
        """colStats works regardless of the element type of the input RDD."""
        # rows that are numpy vectors
        vectors = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
        stats = Statistics.colStats(vectors)
        self.assertEqual(1000, stats.count())
        # rows that are plain Python sequences
        rows = self.sc.parallelize([range(10)] * 10)
        stats = Statistics.colStats(rows)
        self.assertEqual(10, stats.count())
        # rows that are array.array values
        rows = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
        stats = Statistics.colStats(rows)
        self.assertEqual(10, stats.count())
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment