Commit 4fc4d036 authored by MechCoder, committed by Xiangrui Meng

[SPARK-5987] [MLlib] Save/load for GaussianMixtureModels

Should be self-explanatory.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #4986 from MechCoder/spark-5987 and squashes the following commits:

7d2cd56 [MechCoder] Iterate over dataframe in a better way
e7a14cb [MechCoder] Minor
33c84f9 [MechCoder] Store as Array[Data] instead of Data[Array]
505bd57 [MechCoder] Rebased over master and used MatrixUDT
7422bb4 [MechCoder] Store sigmas as Array[Double] instead of Array[Array[Double]]
b9794e4 [MechCoder] Minor
cb77095 [MechCoder] [SPARK-5987] Save/load for GaussianMixtureModels
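The new public API is a save/load round trip on GaussianMixtureModel, mirroring the documentation changes below. A minimal sketch of the intended usage (the input file and output path are placeholders, and sc is assumed to be an existing SparkContext):

import org.apache.spark.mllib.clustering.{GaussianMixture, GaussianMixtureModel}
import org.apache.spark.mllib.linalg.Vectors

// Parse whitespace-separated doubles into dense vectors (placeholder input path).
val parsedData = sc.textFile("data/gmm_data.txt")
  .map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
  .cache()

// Fit a two-component mixture, persist it, and read it back.
val gmm = new GaussianMixture().setK(2).run(parsedData)
gmm.save(sc, "myGMMModel")
val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")

// The loaded model exposes the same weights and per-component Gaussians.
assert(sameModel.k == gmm.k)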
@@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model.
{% highlight scala %}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
// Cluster the data into two classes using GaussianMixture
val gmm = new GaussianMixture().setK(2).run(parsedData)
// Save and load model
gmm.save(sc, "myGMMModel")
val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
// output parameters of max-likelihood model
for (i <- 0 until gmm.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
@@ -231,6 +236,9 @@ public class GaussianMixtureExample {
// Cluster the data into two classes using GaussianMixture
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());
// Save and load GaussianMixtureModel
gmm.save(sc.sc(), "myGMMModel");
GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc.sc(), "myGMMModel");
// Output the parameters of the mixture model
for (int j = 0; j < gmm.k(); j++) {
System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
......
@@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseVector => BreezeVector}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Row}
/**
* :: Experimental ::
@@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
val gaussians: Array[MultivariateGaussian]) extends Serializable {
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable {
require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")
override protected def formatVersion = "1.0"
override def save(sc: SparkContext, path: String): Unit = {
GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)
}
/** Number of gaussians in mixture */
def k: Int = weights.length
@@ -83,5 +95,79 @@ class GaussianMixtureModel(
p(i) /= pSum
}
p
}
}
}
@Experimental
object GaussianMixtureModel extends Loader[GaussianMixtureModel] {
private object SaveLoadV1_0 {
case class Data(weight: Double, mu: Vector, sigma: Matrix)
val formatVersionV1_0 = "1.0"
val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel"
def save(
sc: SparkContext,
path: String,
weights: Array[Double],
gaussians: Array[MultivariateGaussian]): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
// Create JSON metadata.
val metadata = compact(render(
  ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val dataArray = Array.tabulate(weights.length) { i =>
Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
}
sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
}
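// On-disk layout produced by save, assuming Loader.metadataPath(path) and
// Loader.dataPath(path) resolve to the conventional "<path>/metadata" and "<path>/data":
//   <path>/metadata  - one JSON line, e.g. for k = 2:
//       {"class":"org.apache.spark.mllib.clustering.GaussianMixtureModel","version":"1.0","k":2}
//   <path>/data      - a Parquet file with one Data(weight, mu, sigma) row per component.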
def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val dataPath = Loader.dataPath(path)
val sqlContext = new SQLContext(sc)
val dataFrame = sqlContext.parquetFile(dataPath)
val dataArray = dataFrame.select("weight", "mu", "sigma").collect()
// Check schema explicitly since erasure makes it hard to use match-case for checking.
Loader.checkSchema[Data](dataFrame.schema)
val (weights, gaussians) = dataArray.map {
case Row(weight: Double, mu: Vector, sigma: Matrix) =>
(weight, new MultivariateGaussian(mu, sigma))
}.unzip
new GaussianMixtureModel(weights.toArray, gaussians.toArray)
}
}
override def load(sc: SparkContext, path: String): GaussianMixtureModel = {
val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
implicit val formats = DefaultFormats
val k = (metadata \ "k").extract[Int]
val classNameV1_0 = SaveLoadV1_0.classNameV1_0
(loadedClassName, version) match {
// Use backticks so the pattern matches against the classNameV1_0 value instead of binding a new variable.
case (`classNameV1_0`, "1.0") => {
val model = SaveLoadV1_0.load(sc, path)
require(model.weights.length == k,
s"GaussianMixtureModel requires weights of length $k " +
s"but got weights of length ${model.weights.length}")
require(model.gaussians.length == k,
s"GaussianMixtureModel requires gaussians of length $k " +
s"but got gaussians of length ${model.gaussians.length}")
model
}
case _ => throw new Exception(
s"GaussianMixtureModel.load did not recognize model with (className, format version):" +
s"($loadedClassName, $version). Supported:\n" +
s" ($classNameV1_0, 1.0)")
}
}
}
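Because the model data is stored as an ordinary Parquet file of Data(weight, mu, sigma) rows, it can also be inspected directly with Spark SQL outside of the loader. A minimal sketch, assuming a model was already saved to "myGMMModel" and that the data lives under the conventional "<path>/data" directory:

import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// Read back the rows written by GaussianMixtureModel.save; the "/data" suffix is an assumption.
val df = sqlContext.parquetFile("myGMMModel/data")
df.printSchema()  // weight (double), mu (vector UDT), sigma (matrix UDT)
df.select("weight", "mu", "sigma").collect().foreach(println)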
@@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
test("single cluster") {
@@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters") {
val data = sc.parallelize(Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
val data = sc.parallelize(GaussianTestData.data)
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}
test("two clusters with sparse data") {
val data = sc.parallelize(Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
val data = sc.parallelize(GaussianTestData.data)
val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
@@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
test("model save / load") {
val data = sc.parallelize(GaussianTestData.data)
val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
try {
gmm.save(sc, path)
// TODO: GaussianMixtureModel should implement equals/hashcode directly.
val sameModel = GaussianMixtureModel.load(sc, path)
assert(sameModel.k === gmm.k)
(0 until sameModel.k).foreach { i =>
assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
}
} finally {
Utils.deleteRecursively(tempDir)
}
}
object GaussianTestData {
val data = Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
)
}
}