From 50a1a874e1d087a6c79835b1936d0009622a97b1 Mon Sep 17 00:00:00 2001
From: FlytxtRnD <meethu.mathew@flytxt.com>
Date: Mon, 2 Feb 2015 23:04:55 -0800
Subject: [PATCH] [SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture
 Model

Python API for the Gaussian Mixture Model clustering algorithm in MLlib.
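
Example usage (a minimal sketch; assumes an existing SparkContext `sc`, with
data values borrowed from the doctests added below):

    from numpy import array
    from pyspark.mllib.clustering import GaussianMixture

    data = sc.parallelize(array([-0.1, -0.05, -0.01, -0.1,
                                 0.9, 0.8, 0.75, 0.935]).reshape(4, 2))
    model = GaussianMixture.train(data, 2, convergenceTol=1e-3,
                                  maxIterations=100, seed=10)
    for i in range(2):
        # per-component weight, mean vector and covariance matrix
        print(model.weights[i], model.gaussians[i].mu,
              model.gaussians[i].sigma.toArray())
    labels = model.predict(data).collect()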

Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits:

c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and ArrayBuffer in trainGaussianMixture
fa0a142 [FlytxtRnD] New line added
d5b36ab [FlytxtRnD] Changed argument names to lowercase
ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper
6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py
3aee84b [FlytxtRnD] Fixed style issues
2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues
b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel
05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper
426d130 [FlytxtRnD] Added random seed parameter
332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper
f82750b [FlytxtRnD] Fixed style issues
5c83825 [FlytxtRnD] Split input file with space delimiter
fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model
---
 .../python/mllib/gaussian_mixture_model.py    | 65 +++++++++++++
 .../mllib/api/python/PythonMLLibAPI.scala     | 56 ++++++++++-
 python/pyspark/mllib/clustering.py            | 92 ++++++++++++++++++-
 python/pyspark/mllib/stat/__init__.py         |  3 +-
 python/pyspark/mllib/stat/distribution.py     | 31 +++++++
 python/pyspark/mllib/tests.py                 | 26 ++++++
 6 files changed, 267 insertions(+), 6 deletions(-)
 create mode 100644 examples/src/main/python/mllib/gaussian_mixture_model.py
 create mode 100644 python/pyspark/mllib/stat/distribution.py

diff --git a/examples/src/main/python/mllib/gaussian_mixture_model.py b/examples/src/main/python/mllib/gaussian_mixture_model.py
new file mode 100644
index 0000000000..a2cd626c9f
--- /dev/null
+++ b/examples/src/main/python/mllib/gaussian_mixture_model.py
@@ -0,0 +1,65 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A Gaussian Mixture Model clustering program using MLlib.
+"""
+from __future__ import print_function
+import random
+import argparse
+import numpy as np
+
+from pyspark import SparkConf, SparkContext
+from pyspark.mllib.clustering import GaussianMixture
+
+
+def parseVector(line):
+    return np.array([float(x) for x in line.split(' ')])
+
+
+if __name__ == "__main__":
+    """
+    Parameters
+    ----------
+    :param inputFile:        Input file path which contains data points
+    :param k:                Number of mixture components
+    :param convergenceTol:   Convergence threshold. Default to 1e-3
+    :param maxIterations:    Number of EM iterations to perform. Default to 100
+    :param seed:             Random seed
+    """
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('inputFile', help='Input File')
+    parser.add_argument('k', type=int, help='Number of clusters')
+    parser.add_argument('--convergenceTol', default=1e-3, type=float, help='convergence threshold')
+    parser.add_argument('--maxIterations', default=100, type=int, help='Number of iterations')
+    parser.add_argument('--seed', default=random.getrandbits(19),
+                        type=long, help='Random seed')
+    args = parser.parse_args()
+
+    conf = SparkConf().setAppName("GMM")
+    sc = SparkContext(conf=conf)
+
+    lines = sc.textFile(args.inputFile)
+    data = lines.map(parseVector)
+    model = GaussianMixture.train(data, args.k, args.convergenceTol,
+                                  args.maxIterations, args.seed)
+    for i in range(args.k):
+        print ("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
+               "sigma = ", model.gaussians[i].sigma.toArray())
+    print ("Cluster labels (first 100): ", model.predict(data).take(100))
+    sc.stop()
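
The example above expects a plain-text input file with one point per line and
coordinates separated by single spaces (see `parseVector`). A hypothetical
invocation, with `data.txt` as an illustrative file name:

    bin/spark-submit examples/src/main/python/mllib/gaussian_mixture_model.py \
        data.txt 2 --convergenceTol 0.0001 --maxIterations 50
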
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index a66d6f0cf2..980980593d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.existentials
 import scala.reflect.ClassTag
 
@@ -40,6 +41,7 @@ import org.apache.spark.mllib.recommendation._
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
 import org.apache.spark.mllib.stat.test.ChiSqTestResult
 import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
@@ -260,7 +262,7 @@ class PythonMLLibAPI extends Serializable {
   }
 
   /**
-   * Java stub for Python mllib KMeans.train()
+   * Java stub for Python mllib KMeans.run()
    */
   def trainKMeansModel(
       data: JavaRDD[Vector],
@@ -284,6 +286,58 @@ class PythonMLLibAPI extends Serializable {
     }
   }
 
+  /**
+   * Java stub for Python mllib GaussianMixture.run()
+   * Returns a list containing the weights, means and covariances of the mixture components.
+   */
+  def trainGaussianMixture(
+      data: JavaRDD[Vector],
+      k: Int,
+      convergenceTol: Double,
+      maxIterations: Int,
+      seed: java.lang.Long): JList[Object] = {
+    val gmmAlg = new GaussianMixture()
+      .setK(k)
+      .setConvergenceTol(convergenceTol)
+      .setMaxIterations(maxIterations)
+
+    if (seed != null) gmmAlg.setSeed(seed)
+
+    try {
+      val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+      val wt = ArrayBuffer.empty[Double]
+      val mu = ArrayBuffer.empty[Vector]
+      val sigma = ArrayBuffer.empty[Matrix]
+      for (i <- 0 until model.k) {
+        wt += model.weights(i)
+        mu += model.gaussians(i).mu
+        sigma += model.gaussians(i).sigma
+      }
+      List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+    } finally {
+      data.rdd.unpersist(blocking = false)
+    }
+  }
+
+  /**
+   * Java stub for Python mllib GaussianMixtureModel.predictSoft()
+   */
+  def predictSoftGMM(
+      data: JavaRDD[Vector],
+      wt: Object,
+      mu: Array[Object],
+      si: Array[Object]): RDD[Array[Double]] = {
+
+    val weight = wt.asInstanceOf[Array[Double]]
+    val mean = mu.map(_.asInstanceOf[DenseVector])
+    val sigma = si.map(_.asInstanceOf[DenseMatrix])
+    val gaussians = Array.tabulate(weight.length) {
+      i => new MultivariateGaussian(mean(i), sigma(i))
+    }
+    val model = new GaussianMixtureModel(weight, gaussians)
+    model.predictSoft(data)
+  }
+
   /**
   * A wrapper of MatrixFactorizationModel to provide helper methods for Python
    */
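
For reference, the three parallel arrays returned by `trainGaussianMixture`
are unpacked on the Python side exactly as in `GaussianMixture.train` below
(a sketch; `weight` arrives as a list of doubles, `mu` and `sigma` as lists
of DenseVector and DenseMatrix objects):

    weight, mu, sigma = callMLlibFunc("trainGaussianMixture",
                                      rdd.map(_convert_to_vector), k,
                                      convergenceTol, maxIterations, seed)
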
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 6b713aa393..f6b97abb17 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -15,19 +15,22 @@
 # limitations under the License.
 #
 
+from numpy import array
+
+from pyspark import RDD
 from pyspark import SparkContext
 from pyspark.mllib.common import callMLlibFunc, callJavaFunc
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.stat.distribution import MultivariateGaussian
 
-__all__ = ['KMeansModel', 'KMeans']
+__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
 
 
 class KMeansModel(object):
 
     """A clustering model derived from the k-means method.
 
-    >>> from numpy import array
-    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
+    >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2)
     >>> model = KMeans.train(
     ...     sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
     >>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0]))
@@ -86,6 +89,87 @@ class KMeans(object):
         return KMeansModel([c.toArray() for c in centers])
 
 
+class GaussianMixtureModel(object):
+
+    """A clustering model derived from the Gaussian Mixture Model method.
+
+    >>> clusterdata_1 =  sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
+    ...                                         0.9,0.8,0.75,0.935,
+    ...                                        -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
+    >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
+    ...                                 maxIterations=50, seed=10)
+    >>> labels = model.predict(clusterdata_1).collect()
+    >>> labels[0]==labels[1]
+    False
+    >>> labels[1]==labels[2]
+    True
+    >>> labels[4]==labels[5]
+    True
+    >>> clusterdata_2 =  sc.parallelize(array([-5.1971, -2.5359, -3.8220,
+    ...                                        -5.2211, -5.0602,  4.7118,
+    ...                                         6.8989, 3.4592,  4.6322,
+    ...                                         5.7048,  4.6567, 5.5026,
+    ...                                         4.5605,  5.2043,  6.2734]).reshape(5, 3))
+    >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
+    ...                                 maxIterations=150, seed=10)
+    >>> labels = model.predict(clusterdata_2).collect()
+    >>> labels[0]==labels[1]==labels[2]
+    True
+    >>> labels[3]==labels[4]
+    True
+    """
+
+    def __init__(self, weights, gaussians):
+        self.weights = weights
+        self.gaussians = gaussians
+        self.k = len(self.weights)
+
+    def predict(self, x):
+        """
+        Find the cluster to which the points in 'x' have maximum membership
+        in this model.
+
+        :param x:    RDD of data points.
+        :return:     RDD of cluster labels.
+        """
+        if isinstance(x, RDD):
+            cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
+            return cluster_labels
+
+    def predictSoft(self, x):
+        """
+        Find the membership of each point in 'x' to all mixture components.
+
+        :param x:    RDD of data points.
+        :return:     RDD of arrays of membership values, one value per component.
+        """
+        if isinstance(x, RDD):
+            means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
+            membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
+                                              self.weights, means, sigmas)
+            return membership_matrix
+
+
+class GaussianMixture(object):
+    """
+    Estimate model parameters with the expectation-maximization algorithm.
+
+    :param rdd:             RDD of data points
+    :param k:               Number of mixture components
+    :param convergenceTol:  Threshold value to check the convergence criteria. Defaults to 1e-3
+    :param maxIterations:   Number of EM iterations to perform. Defaults to 100
+    :param seed:            Random seed
+    """
+    @classmethod
+    def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None):
+        """Train a Gaussian Mixture clustering model."""
+        weight, mu, sigma = callMLlibFunc("trainGaussianMixture",
+                                          rdd.map(_convert_to_vector), k,
+                                          convergenceTol, maxIterations, seed)
+        mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
+        return GaussianMixtureModel(weight, mvg_obj)
+
+
 def _test():
     import doctest
     globs = globals().copy()
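
A sketch of soft assignments with the new API, assuming the `model` and
`clusterdata_1` RDD from the doctest above:

    memberships = model.predictSoft(clusterdata_1).collect()
    # each row holds one membership value per component and sums to ~1.0
    hard = [row.index(max(row)) for row in memberships]
    # `hard` reproduces model.predict(clusterdata_1).collect()
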
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
index 799d260c09..b686d955a0 100644
--- a/python/pyspark/mllib/stat/__init__.py
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -20,5 +20,6 @@ Python package for statistical functions in MLlib.
 """
 
 from pyspark.mllib.stat._statistics import *
+from pyspark.mllib.stat.distribution import MultivariateGaussian
 
-__all__ = ["Statistics", "MultivariateStatisticalSummary"]
+__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"]
diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py
new file mode 100644
index 0000000000..07792e1532
--- /dev/null
+++ b/python/pyspark/mllib/stat/distribution.py
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+__all__ = ['MultivariateGaussian']
+
+
+class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])):
+
+    """ Represents a (mu, sigma) tuple
+    >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
+    >>> (m.mu, m.sigma.toArray())
+    (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]]))
+    >>> (m[0], m[1])
+    (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]]))
+    """
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 61e0cf5d90..42aa228737 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -167,6 +167,32 @@ class ListTests(PySparkTestCase):
             # TODO: Allow small numeric difference.
             self.assertTrue(array_equal(c1, c2))
 
+    def test_gmm(self):
+        from pyspark.mllib.clustering import GaussianMixture
+        data = self.sc.parallelize([
+            [1, 2],
+            [8, 9],
+            [-4, -3],
+            [-6, -7],
+        ])
+        clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                         maxIterations=100, seed=56)
+        labels = clusters.predict(data).collect()
+        self.assertEqual(labels[0], labels[1])
+        self.assertEqual(labels[2], labels[3])
+
+    def test_gmm_deterministic(self):
+        from pyspark.mllib.clustering import GaussianMixture
+        x = range(0, 100, 10)
+        y = range(0, 100, 10)
+        data = self.sc.parallelize([[a, b] for a, b in zip(x, y)])
+        clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                          maxIterations=100, seed=63)
+        clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+                                          maxIterations=100, seed=63)
+        for c1, c2 in zip(clusters1.weights, clusters2.weights):
+            self.assertEqual(round(c1, 7), round(c2, 7))
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
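
The determinism test above hinges on the seeded Scala implementation; an
equivalent quick check from a PySpark shell (a sketch; assumes `sc`):

    from pyspark.mllib.clustering import GaussianMixture

    data = sc.parallelize([[a, a] for a in range(0, 100, 10)])
    m1 = GaussianMixture.train(data, 5, seed=63)
    m2 = GaussianMixture.train(data, 5, seed=63)
    assert [round(w, 7) for w in m1.weights] == \
           [round(w, 7) for w in m2.weights]
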
-- 
GitLab