From 764ca18037b6b1884fbc4be9a011714a81495020 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Mon, 22 Feb 2016 23:54:21 -0800
Subject: [PATCH] [SPARK-13355][MLLIB] replace GraphImpl.fromExistingRDDs by
 Graph.apply

`GraphImpl.fromExistingRDDs` expects preprocessed vertex RDD as input. We call it in LDA without validating this requirement. So it might introduce errors. Replacing it by `Graph.apply` would be safer and more proper because it is a public API. The tests still pass. So maybe it is safe to use `fromExistingRDDs` here (though it doesn't seem so based on the implementation) or the test cases are special. jkbradley ankurdave

Author: Xiangrui Meng <meng@databricks.com>

Closes #11226 from mengxr/SPARK-13355.
---
 .../scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7a41f74191..7491ab0d51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -25,7 +25,6 @@ import breeze.stats.distributions.{Gamma, RandBasis}
 
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
@@ -188,7 +187,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
       graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
         .mapValues(_._2)
     // Update the vertex descriptors with the new counts.
-    val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
+    val newGraph = Graph(docTopicDistributions, graph.edges)
     graph = newGraph
     graphCheckpointer.update(newGraph)
     globalTopicTotals = computeGlobalTopicTotals()
-- 
GitLab