Skip to content
Snippets Groups Projects
Commit c4f0b4f3 authored by Travis Galoppo's avatar Travis Galoppo Committed by Xiangrui Meng
Browse files

SPARK-5020 [MLlib] GaussianMixtureModel.predictMembership() should take an RDD only

Removed unnecessary parameters to predictMembership()

CC: jkbradley

Author: Travis Galoppo <tjg2107@columbia.edu>

Closes #3854 from tgaloppo/spark-5020 and squashes the following commits:

1bf4669 [Travis Galoppo] renamed predictMembership() to predictSoft()
0f1d96e [Travis Galoppo] SPARK-5020 - Removed superfluous parameters from predictMembership()
parent fdc2aa49
No related branches found
No related tags found
No related merge requests found
......@@ -45,7 +45,7 @@ class GaussianMixtureModel(
/** Maps given points to their cluster indices. */
def predict(points: RDD[Vector]): RDD[Int] = {
val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
val responsibilityMatrix = predictSoft(points)
responsibilityMatrix.map(r => r.indexOf(r.max))
}
......@@ -53,12 +53,7 @@ class GaussianMixtureModel(
* Given the input vectors, return the membership value of each vector
* to all mixture components.
*/
def predictMembership(
points: RDD[Vector],
mu: Array[Vector],
sigma: Array[Matrix],
weight: Array[Double],
k: Int): RDD[Array[Double]] = {
def predictSoft(points: RDD[Vector]): RDD[Array[Double]] = {
val sc = points.sparkContext
val dists = sc.broadcast {
(0 until k).map { i =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment