Skip to content
Snippets Groups Projects
Commit 8f4aaba0 authored by FlytxtRnD's avatar FlytxtRnD Committed by Joseph K. Bradley
Browse files

[SPARK-7651] [MLLIB] [PYSPARK] GMM predict, predictSoft should raise error on bad input

In the Python API for Gaussian Mixture Model, predict() and predictSoft() methods should raise an error when the input argument is not an RDD.

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #6180 from FlytxtRnD/GmmPredictException and squashes the following commits:

4b6aa11 [FlytxtRnD] Raise error if the input to predict()/predictSoft() is not an RDD
parent f96b85ab
No related branches found
No related tags found
No related merge requests found
......@@ -212,6 +212,9 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
def predictSoft(self, x):
"""
......@@ -225,6 +228,9 @@ class GaussianMixtureModel(object):
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self._weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
class GaussianMixture(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment