diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index fcb9a3643cc4895d1079a0ce886ec01750170d91..9b5c155b0a805da80c28fdeeae15c063982e889f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -43,15 +43,19 @@ class PowerIterationClusteringModel( * * @param k Number of clusters. * @param maxIterations Maximum number of iterations of the PIC algorithm. + * @param initMode Initialization mode. */ class PowerIterationClustering private[clustering] ( private var k: Int, - private var maxIterations: Int) extends Serializable { + private var maxIterations: Int, + private var initMode: String) extends Serializable { import org.apache.spark.mllib.clustering.PowerIterationClustering._ - /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */ - def this() = this(k = 2, maxIterations = 100) + /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100, + * initMode: "random"}. + */ + def this() = this(k = 2, maxIterations = 100, initMode = "random") /** * Set the number of clusters. @@ -69,6 +73,18 @@ class PowerIterationClustering private[clustering] ( this } + /** + * Set the initialization mode. This can be either "random" to use a random vector + * as vertex properties, or "degree" to use normalized sum similarities. Default: random. + */ + def setInitializationMode(mode: String): this.type = { + this.initMode = mode match { + case "random" | "degree" => mode + case _ => throw new IllegalArgumentException("Invalid initialization mode: " + mode) + } + this + } + /** * Run the PIC algorithm. * @@ -82,7 +98,10 @@ class PowerIterationClustering private[clustering] ( */ def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { val w = normalize(similarities) - val w0 = randomInit(w) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } pic(w0) } @@ -148,6 +167,20 @@ private[clustering] object PowerIterationClustering extends Logging { GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) } + /** + * Generates the degree vector as the vertex properties (v0) to start power iteration. + * It is not exactly the node degrees but just the normalized sum similarities. Call it + * as degree vector because it is used in the PIC paper. + * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing the degree vector + */ + def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { + val sum = g.vertices.values.sum() + val v0 = g.vertices.mapValues(_ / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + /** * Runs power iteration. * @param g input graph with edges representing the normalized affinity matrix (W) and vertices diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 2bae465d392aa58c90270a985c7f94e8c65b6008..03ecd9ca730be901731606b8a69a89dd95769629 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -55,6 +55,16 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext predictions(c) += i } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { case (i, c) => + predictions2(c) += i + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } test("normalize and powerIter") {