diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index e2c6aca553c1cd1be28814e586cc21b31e22a114..ae324f86fe6d18e498f775da01241c5dc768182b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -28,6 +28,7 @@ import org.apache.spark.graphx._
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 
 /**
  * :: DeveloperApi ::
@@ -472,12 +473,13 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
         gammaPart = gammad :: gammaPart
       }
       Iterator((stat, gammaPart))
-    }
+    }.persist(StorageLevel.MEMORY_AND_DISK)
     val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))(
       _ += _, _ += _)
-    expElogbetaBc.unpersist()
     val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat(
       stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
+    stats.unpersist()
+    expElogbetaBc.destroy(false)
     val batchResult = statsSum :* expElogbeta.t
 
     // Note that this is an optimization to avoid batch.count
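
Why the change helps (commentary, not part of the patch): stats feeds two actions, the treeAggregate that builds statsSum and the collect that builds gammat. Without persistence, the second action re-runs the mapPartitions closure, recomputing the per-partition statistics and re-reading a broadcast that the old code had already unpersisted between the two actions. Persisting stats at MEMORY_AND_DISK lets the second pass read cached blocks instead, after which the RDD is unpersisted and the broadcast destroyed outright; destroy(false) is the non-blocking form, and unlike unpersist, destroy also releases the driver-side copy so the variable can never be rebroadcast.

A minimal standalone sketch of the same persist/reuse/clean-up pattern, using hypothetical names (weightsBc, stats, PersistPatternSketch) and only spark-core APIs; it calls the public destroy() rather than the destroy(false) overload used inside Spark's own packages:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

object PersistPatternSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("persist-sketch").setMaster("local[*]"))

    // Read-only data shared with every task, analogous to expElogbetaBc.
    val weightsBc: Broadcast[Array[Double]] = sc.broadcast(Array.tabulate(100)(_ * 0.01))

    // Per-partition results, analogous to the (stat, gammaPart) pairs in submitMiniBatch.
    val stats: RDD[(Double, List[Double])] = sc.parallelize(1 to 1000)
      .mapPartitions { it =>
        val w = weightsBc.value
        val xs = it.toList
        Iterator((xs.map(i => w(i % w.length)).sum, xs.map(_.toDouble)))
      }
      .persist(StorageLevel.MEMORY_AND_DISK) // cached because two actions read it

    // Action 1: aggregate the first component (stands in for treeAggregate/statsSum).
    val total: Double = stats.map(_._1).reduce(_ + _)

    // Action 2: collect the second component (stands in for the collect/gammat);
    // this reads the cached blocks instead of recomputing the mapPartitions closure.
    val gathered: Array[Double] = stats.flatMap(_._2).collect()

    stats.unpersist()   // drop the cached blocks once both actions are done
    weightsBc.destroy() // release executor and driver copies; the broadcast is never reused

    println(s"total = $total, collected ${gathered.length} values")
    sc.stop()
  }
}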