diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 920b57756b6254df90dbdb8761acdcd3c14964c1..31c1d520fd6595563bfd54bfb7c9b8ec8c805986 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -283,12 +283,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { */ @Experimental class DistributedLDAModel private ( - private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], - private val globalTopicTotals: LDA.TopicCounts, + private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private[clustering] val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, - private val docConcentration: Double, - private val topicConcentration: Double, + private[clustering] val docConcentration: Double, + private[clustering] val topicConcentration: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { import LDA._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index da70d9bd7c7902f75836ab8065f34325d6cfba4f..376a87f0511b413ef3f02b817ed70dcb9c5ab692 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite +import org.apache.spark.graphx.Edge import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -318,6 +319,20 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.k === sameDistributedModel.k) assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) + assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) + assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) + + val graph = distributedModel.graph + val sameGraph = sameDistributedModel.graph + assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect()) + val edge = graph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + val sameEdge = sameGraph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + assert(edge === sameEdge) } finally { Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2)