diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 92e05815d6a3d3a64d56c165cde1802c3a389229..830510b1698d4a3ac71ff18a8c7782320ea341cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.clustering
 
+import org.apache.hadoop.fs.Path
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
     EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
     LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
     @Since("1.6.0") override val uid: String,
     @Since("1.6.0") val vocabSize: Int,
     @Since("1.6.0") @transient protected val sqlContext: SQLContext)
-  extends Model[LDAModel] with LDAParams with Logging {
+  extends Model[LDAModel] with LDAParams with Logging with MLWritable {
 
   // NOTE to developers:
   // This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,64 @@ class LocalLDAModel private[ml] (
 
   @Since("1.6.0")
   override def isDistributed: Boolean = false
+
+  @Since("1.6.0")
+  override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+  private[LocalLDAModel]
+  class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+    private case class Data(
+      vocabSize: Int,
+      topicsMatrix: Matrix,
+      docConcentration: Vector,
+      topicConcentration: Double,
+      gammaShape: Double)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val oldModel = instance.oldLocalModel
+      val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
+        oldModel.topicConcentration, oldModel.gammaShape)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+    private val className = classOf[LocalLDAModel].getName
+
+    override def load(path: String): LocalLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
+          "gammaShape")
+        .head()
+      val vocabSize = data.getAs[Int](0)
+      val topicsMatrix = data.getAs[Matrix](1)
+      val docConcentration = data.getAs[Vector](2)
+      val topicConcentration = data.getAs[Double](3)
+      val gammaShape = data.getAs[Double](4)
+      val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+        gammaShape)
+      val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): LocalLDAModel = super.load(path)
 }
 
 
@@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] (
    */
   @Since("1.6.0")
   lazy val logPrior: Double = oldDistributedModel.logPrior
+
+  @Since("1.6.0")
+  override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+  private[DistributedLDAModel]
+  class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val modelPath = new Path(path, "oldModel").toString
+      instance.oldDistributedModel.save(sc, modelPath)
+    }
+  }
+
+  private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+    private val className = classOf[DistributedLDAModel].getName
+
+    override def load(path: String): DistributedLDAModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val modelPath = new Path(path, "oldModel").toString
+      val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+      val model = new DistributedLDAModel(
+        metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): DistributedLDAModel = super.load(path)
 }
 
 
@@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] (
 @Since("1.6.0")
 @Experimental
 class LDA @Since("1.6.0") (
-    @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+    @Since("1.6.0") override val uid: String)
+  extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
 
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +794,7 @@ class LDA @Since("1.6.0") (
 }
 
 
-private[clustering] object LDA {
+private[clustering] object LDA extends DefaultParamsReadable[LDA] {
 
   /** Get dataset for spark.mllib LDA */
   def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +805,7 @@ private[clustering] object LDA {
         (docId, features)
       }
   }
+
+  @Since("1.6.0")
+  override def load(path: String): LDA = super.load(path)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index cd520f09bd4664d102c61ecc61bf074cd076154a..7384d065a2ea8ffc113b95933617431e68a7ea9d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable {
  * @param topics Inferred topics (vocabSize x k matrix).
  */
 @Since("1.3.0")
-class LocalLDAModel private[clustering] (
+class LocalLDAModel private[spark] (
     @Since("1.3.0") val topics: Matrix,
     @Since("1.5.0") override val docConcentration: Vector,
     @Since("1.5.0") override val topicConcentration: Double,
-    override protected[clustering] val gammaShape: Double = 100)
+    override protected[spark] val gammaShape: Double = 100)
   extends LDAModel with Serializable {
 
   @Since("1.3.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index b634d31cc34f025b0fa46c79cf9fcc27b4c7ef8b..97dbfd9a4314a5db98895fbf2ac17be4f732aef5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -18,9 +18,10 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 
@@ -39,10 +40,24 @@ object LDASuite {
     }.map(v => new TestRow(v))
     sql.createDataFrame(rdd)
   }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 2,
+    "checkpointInterval" -> 30,
+    "learningOffset" -> 1023.0,
+    "learningDecay" -> 0.52,
+    "subsamplingRate" -> 0.051
+  )
 }
 
 
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   val k: Int = 5
   val vocabSize: Int = 30
@@ -218,4 +233,29 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     val lp = model.logPrior
     assert(lp <= 0.0 && lp != Double.NegativeInfinity)
   }
+
+  test("read/write LocalLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+  }
+
+  test("read/write DistributedLDAModel") {
+    def checkModelData(model: LDAModel, model2: LDAModel): Unit = {
+      assert(model.vocabSize === model2.vocabSize)
+      assert(Vectors.dense(model.topicsMatrix.toArray) ~==
+        Vectors.dense(model2.topicsMatrix.toArray) absTol 1e-6)
+      assert(Vectors.dense(model.getDocConcentration) ~==
+        Vectors.dense(model2.getDocConcentration) absTol 1e-6)
+    }
+    val lda = new LDA()
+    testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
+  }
 }