Skip to content
Snippets Groups Projects
Commit 1e340c3a authored by Xusen Yin's avatar Xusen Yin Committed by Xiangrui Meng
Browse files

[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel

See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988).

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits:

cb1ecfa [Xusen Yin] change Assignment into case class
b1dd24c [Xusen Yin] add test suite
63c3923 [Xusen Yin] add save load for power iteration clustering
parent 6cc5b3ed
No related branches found
No related tags found
No related merge requests found
......@@ -17,15 +17,20 @@
package org.apache.spark.mllib.clustering
import org.apache.spark.{Logging, SparkException}
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.{Logging, SparkContext, SparkException}
/**
* :: Experimental ::
......@@ -38,7 +43,60 @@ import org.apache.spark.util.random.XORShiftRandom
@Experimental
class PowerIterationClusteringModel(
val k: Int,
val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
val assignments: RDD[PowerIterationClustering.Assignment]) extends Saveable with Serializable {
override def save(sc: SparkContext, path: String): Unit = {
PowerIterationClusteringModel.SaveLoadV1_0.save(sc, this, path)
}
override protected def formatVersion: String = "1.0"
}
object PowerIterationClusteringModel extends Loader[PowerIterationClusteringModel] {
override def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
PowerIterationClusteringModel.SaveLoadV1_0.load(sc, path)
}
private[clustering]
object SaveLoadV1_0 {
private val thisFormatVersion = "1.0"
private[clustering]
val thisClassName = "org.apache.spark.mllib.clustering.PowerIterationClusteringModel"
def save(sc: SparkContext, model: PowerIterationClusteringModel, path: String): Unit = {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
val dataRDD = model.assignments.toDF()
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
def load(sc: SparkContext, path: String): PowerIterationClusteringModel = {
implicit val formats = DefaultFormats
val sqlContext = new SQLContext(sc)
val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
val k = (metadata \ "k").extract[Int]
val assignments = sqlContext.parquetFile(Loader.dataPath(path))
Loader.checkSchema[PowerIterationClustering.Assignment](assignments.schema)
val assignmentsRDD = assignments.map {
case Row(id: Long, cluster: Int) => PowerIterationClustering.Assignment(id, cluster)
}
new PowerIterationClusteringModel(k, assignmentsRDD)
}
}
}
/**
* :: Experimental ::
......@@ -135,7 +193,7 @@ class PowerIterationClustering private[clustering] (
val v = powerIter(w, maxIterations)
val assignments = kMeans(v, k).mapPartitions({ iter =>
iter.map { case (id, cluster) =>
new Assignment(id, cluster)
Assignment(id, cluster)
}
}, preservesPartitioning = true)
new PowerIterationClusteringModel(k, assignments)
......@@ -152,7 +210,7 @@ object PowerIterationClustering extends Logging {
* @param cluster assigned cluster id
*/
@Experimental
class Assignment(val id: Long, val cluster: Int) extends Serializable
case class Assignment(id: Long, cluster: Int)
/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
......
......@@ -18,12 +18,15 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable
import scala.util.Random
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
......@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
assert(x ~== u1(i.toInt) absTol 1e-14)
}
}
test("model save/load") {
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
val model = PowerIterationClusteringSuite.createModel(sc, 3, 10)
try {
model.save(sc, path)
val sameModel = PowerIterationClusteringModel.load(sc, path)
PowerIterationClusteringSuite.checkEqual(model, sameModel)
} finally {
Utils.deleteRecursively(tempDir)
}
}
}
object PowerIterationClusteringSuite extends FunSuite {
def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
val assignments = sc.parallelize(
(0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
new PowerIterationClusteringModel(k, assignments)
}
def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
assert(a.k === b.k)
val aAssignments = a.assignments.map(x => (x.id, x.cluster))
val bAssignments = b.assignments.map(x => (x.id, x.cluster))
val unequalElements = aAssignments.join(bAssignments).filter {
case (id, (c1, c2)) => c1 != c2 }.count()
assert(unequalElements === 0L)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment