Skip to content
Snippets Groups Projects
Commit f86a89a2 authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Xiangrui Meng
Browse files

[SPARK-5714][Mllib] Refactor initial step of LDA to remove redundant operations

The `initialState` of LDA performs several RDD operations that look redundant. This PR tries to simplify these operations.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #4501 from viirya/sim_lda and squashes the following commits:

4870fe4 [Liang-Chi Hsieh] For comments.
9af1487 [Liang-Chi Hsieh] Refactor initial step of LDA to remove redundant operations.
parent b8f88d32
No related branches found
No related tags found
No related merge requests found
@@ -450,34 +450,23 @@ private[clustering] object LDA {
     // Create vertices.
     // Initially, we use random soft assignments of tokens to topics (random gamma).
-    val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
-      edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
-        val random = new Random(partIndex + randomSeed)
-        partEdges.map { edge =>
-          // Create a random gamma_{wjk}
-          (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
-        }
-      }
-    def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
-      val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
-        edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
-          (sendToWhere(edge), (edge.attr, gamma))
-        }
-      verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
-        (sum, t) => {
-          brzAxpy(t._1, t._2, sum)
-          sum
-        },
-        (sum0, sum1) => {
-          sum0 += sum1
-        }
-      )
+    def createVertices(): RDD[(VertexId, TopicCounts)] = {
+      val verticesTMP: RDD[(VertexId, TopicCounts)] =
+        edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+          val random = new Random(partIndex + randomSeed)
+          partEdges.flatMap { edge =>
+            val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+            val sum = gamma * edge.attr
+            Seq((edge.srcId, sum), (edge.dstId, sum))
+          }
+        }
+      verticesTMP.reduceByKey(_ + _)
     }
-    val docVertices = createVertices(_.srcId)
-    val termVertices = createVertices(_.dstId)
+
+    val docTermVertices = createVertices()
+
     // Partition such that edges are grouped by document
-    val graph = Graph(docVertices ++ termVertices, edges)
+    val graph = Graph(docTermVertices, edges)
       .partitionBy(PartitionStrategy.EdgePartition1D)
     new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment