diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index c71eec5ce043789ab4a151e6de0c1d5fecee6d39..4404cffc292aaedfe372caf8cec379652231c639 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -25,6 +25,9 @@ exportMethods("glm", "fitted", "spark.naiveBayes", "spark.survreg", + "spark.lda", + "spark.posterior", + "spark.perplexity", "spark.isoreg", "spark.gaussianMixture") diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 06bb25d62d34d77b9b8b9d4de72c17bc87724274..fe04bcfc7d14df62414478eca59c9ff8397d8ee1 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s #' @export setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) +#' @rdname spark.lda +#' @param ... Additional parameters to tune LDA. +#' @export +setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") }) + +#' @rdname spark.lda +#' @export +setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) + #' @rdname spark.isoreg #' @export setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") }) @@ -1315,6 +1328,7 @@ setGeneric("spark.gaussianMixture", standardGeneric("spark.gaussianMixture") }) +#' write.ml #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index db74046056a99ab8ba472f19ba61ff4364ae77bb..b9527410a98531e436b27129b454d7ca90331337 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) +#' S4 class that represents an LDAModel +#' +#' @param jobj a Java object reference to the backing Scala LDAWrapper +#' @export +#' @note LDAModel since 2.1.0 +setClass("LDAModel", representation(jobj = "jobj")) + #' S4 class that represents a AFTSurvivalRegressionModel #' #' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper @@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture} -#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda} #' @seealso \link{spark.isoreg} #' @seealso \link{read.ml} NULL @@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"), return(list(apriori = apriori, tables = tables)) }) +# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() + +#' @param newData A SparkDataFrame for testing +#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities +#' vectors named "topicDistribution" +#' @rdname spark.lda +#' @aliases spark.posterior,LDAModel,SparkDataFrame-method +#' @export +#' @note spark.posterior(LDAModel) since 2.1.0 +setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) + }) + +# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. +#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. +#' @return \code{summary} returns a list containing +#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for +#' the prior placed on documents distributions over topics \code{theta}} +#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or +#' \code{eta} for the prior placed on topic distributions over terms} +#' \item{\code{logLikelihood}}{log likelihood of the entire corpus} +#' \item{\code{logPerplexity}}{log perplexity} +#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model} +#' \item{\code{vocabSize}}{number of terms in the corpus} +#' \item{\code{topics}}{top 10 terms and their weights of all topics} +#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file +#' used as training set} +#' @rdname spark.lda +#' @aliases summary,LDAModel-method +#' @export +#' @note summary(LDAModel) since 2.1.0 +setMethod("summary", signature(object = "LDAModel"), + function(object, maxTermsPerTopic) { + maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic)) + jobj <- object@jobj + docConcentration <- callJMethod(jobj, "docConcentration") + topicConcentration <- callJMethod(jobj, "topicConcentration") + logLikelihood <- callJMethod(jobj, "logLikelihood") + logPerplexity <- callJMethod(jobj, "logPerplexity") + isDistributed <- callJMethod(jobj, "isDistributed") + vocabSize <- callJMethod(jobj, "vocabSize") + topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic)) + vocabulary <- callJMethod(jobj, "vocabulary") + return(list(docConcentration = unlist(docConcentration), + topicConcentration = topicConcentration, + logLikelihood = logLikelihood, logPerplexity = logPerplexity, + isDistributed = isDistributed, vocabSize = vocabSize, + topics = topics, + vocabulary = unlist(vocabulary))) + }) + +# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log +#' perplexity of the training data if missing argument "data". +#' @rdname spark.lda +#' @aliases spark.perplexity,LDAModel-method +#' @export +#' @note spark.perplexity(LDAModel) since 2.1.0 +setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), + function(object, data) { + return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"), + callJMethod(object@jobj, "computeLogPerplexity", data@sdf))) + }) + +# Saves the Latent Dirichlet Allocation model to the input path. + +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.lda +#' @aliases write.ml,LDAModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(LDAModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "LDAModel", path = "character"), + function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) + }) + #' Isotonic Regression Model #' #' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). @@ -700,6 +795,8 @@ read.ml <- function(path) { return(new("GeneralizedLinearRegressionModel", jobj = jobj)) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) { return(new("KMeansModel", jobj = jobj)) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) { + return(new("LDAModel", jobj = jobj)) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { return(new("IsotonicRegressionModel", jobj = jobj)) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { @@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula return(new("AFTSurvivalRegressionModel", jobj = jobj)) }) +#' Latent Dirichlet Allocation +#' +#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call +#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute +#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new +#' data and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data A SparkDataFrame for training +#' @param features Features column name, default "features". Either libSVM-format column or +#' character-format column is valid. +#' @param k Number of topics, default 10 +#' @param maxIter Maximum iterations, default 20 +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in +#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for +#' the prior placed on topic distributions over terms, default -1 to set automatically on the +#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size +#' numeric is accepted. +#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the +#' prior placed on documents distributions over topics (\code{theta}), default -1 to set +#' automatically on the Spark side. Use \code{summary} to retrieve the effective +#' docConcentration. Only 1-size or \code{k}-size numeric is accepted. +#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the +#' parameter if libSVM-format column is used as the features column. +#' @param maxVocabSize maximum vocabulary size, default 1 << 18 +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model +#' @rdname spark.lda +#' @aliases spark.lda,SparkDataFrame-method +#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/} +#' @export +#' @examples +#' \dontrun{ +#' text <- read.df("path/to/data", source = "libsvm") +#' model <- spark.lda(data = text, optimizer = "em") +#' +#' # get a summary of the model +#' summary(model) +#' +#' # compute posterior probabilities +#' posterior <- spark.posterior(model, df) +#' showDF(posterior) +#' +#' # compute perplexity +#' perplexity <- spark.perplexity(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.lda since 2.1.0 +setMethod("spark.lda", signature(data = "SparkDataFrame"), + function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"), + subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1, + customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) { + optimizer <- match.arg(optimizer) + jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features, + as.integer(k), as.integer(maxIter), optimizer, + as.numeric(subsamplingRate), topicConcentration, + as.array(docConcentration), as.array(customizedStopWords), + maxVocabSize) + return(new("LDAModel", jobj = jobj)) + }) # Returns a summary of the AFT survival regression model produced by spark.survreg, # similarly to R's summary(). @@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) - }) + }) \ No newline at end of file diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 96179864a88bf090ac1550cfead9ed69d0bd36fc..8c380fbf150f48c5f61eb96168408f2edf8fdd29 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", { unlink(modelPath) }) +test_that("spark.lda with libsvm", { + text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") + model <- spark.lda(text, optimizer = "em") + + stats <- summary(model, 10) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 11) + expect_true(is.null(vocabulary)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + + unlink(modelPath) +}) + +test_that("spark.lda with text input", { + text <- read.text("data/mllib/sample_lda_data.txt") + model <- spark.lda(text, optimizer = "online", features = "value") + + stats <- summary(model) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 10) + expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_true(all.equal(vocabulary, stats2$vocabulary)) + + unlink(modelPath) +}) + +test_that("spark.posterior and spark.perplexity", { + text <- read.text("data/mllib/sample_lda_data.txt") + model <- spark.lda(text, features = "value", k = 3) + + # Assert perplexities are equal + stats <- summary(model) + logPerplexity <- spark.perplexity(model, text) + expect_equal(logPerplexity, stats$logPerplexity) + + # Assert the sum of every topic distribution is equal to 1 + posterior <- spark.posterior(model, text) + local.posterior <- collect(posterior)$topicDistribution + expect_equal(length(local.posterior), sum(unlist(local.posterior))) +}) + sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 034f2c3fa2fd935534d1cae02c5fee189e548925..b5a764b5863f1f893a577fc54e6990a37ef63b3b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] ( @Since("1.6.0") protected def getModel: OldLDAModel + private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray + + private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration + /** * The features for LDA should be a [[Vector]] representing the word counts in a document. * The vector should be of length vocabSize, with counts for each term (word). diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..cbe6a705007d1db2d7f9218472defd96bfcf3bbe --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkException +import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage} +import org.apache.spark.ml.clustering.{LDA, LDAModel} +import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} +import org.apache.spark.ml.linalg.{Vector, VectorUDT} +import org.apache.spark.ml.param.ParamPair +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.StringType + + +private[r] class LDAWrapper private ( + val pipeline: PipelineModel, + val logLikelihood: Double, + val logPerplexity: Double, + val vocabulary: Array[String]) extends MLWritable { + + import LDAWrapper._ + + private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel] + private val preprocessor: PipelineModel = + new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1)) + + def transform(data: Dataset[_]): DataFrame = { + val vec2ary = udf { vec: Vector => vec.toArray } + val outputCol = lda.getTopicDistributionCol + val tempCol = s"${Identifiable.randomUID(outputCol)}" + val preprocessed = preprocessor.transform(data) + lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol)) + .withColumn(outputCol, vec2ary(col(tempCol))) + .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol) + } + + def computeLogPerplexity(data: Dataset[_]): Double = { + lda.logPerplexity(preprocessor.transform(data)) + } + + def topics(maxTermsPerTopic: Int): DataFrame = { + val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic) + if (vocabulary.isEmpty || vocabulary.length < vocabSize) { + topicIndices + } else { + val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) } + topicIndices + .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights")) + } + } + + lazy val isDistributed: Boolean = lda.isDistributed + lazy val vocabSize: Int = lda.vocabSize + lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration + lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration + + override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this) +} + +private[r] object LDAWrapper extends MLReadable[LDAWrapper] { + + val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}" + val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}" + val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}" + + private def getPreStages( + features: String, + customizedStopWords: Array[String], + maxVocabSize: Int): Array[PipelineStage] = { + val tokenizer = new RegexTokenizer() + .setInputCol(features) + .setOutputCol(TOKENIZER_COL) + val stopWordsRemover = new StopWordsRemover() + .setInputCol(TOKENIZER_COL) + .setOutputCol(STOPWORDS_REMOVER_COL) + stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords) + val countVectorizer = new CountVectorizer() + .setVocabSize(maxVocabSize) + .setInputCol(STOPWORDS_REMOVER_COL) + .setOutputCol(COUNT_VECTOR_COL) + + Array(tokenizer, stopWordsRemover, countVectorizer) + } + + def fit( + data: DataFrame, + features: String, + k: Int, + maxIter: Int, + optimizer: String, + subsamplingRate: Double, + topicConcentration: Double, + docConcentration: Array[Double], + customizedStopWords: Array[String], + maxVocabSize: Int): LDAWrapper = { + + val lda = new LDA() + .setK(k) + .setMaxIter(maxIter) + .setSubsamplingRate(subsamplingRate) + + val featureSchema = data.schema(features) + val stages = featureSchema.dataType match { + case d: StringType => + getPreStages(features, customizedStopWords, maxVocabSize) ++ + Array(lda.setFeaturesCol(COUNT_VECTOR_COL)) + case d: VectorUDT => + Array(lda.setFeaturesCol(features)) + case _ => + throw new SparkException( + s"Unsupported input features type of ${featureSchema.dataType.typeName}," + + s" only String type and Vector type are supported now.") + } + + if (topicConcentration != -1) { + lda.setTopicConcentration(topicConcentration) + } else { + // Auto-set topicConcentration + } + + if (docConcentration.length == 1) { + if (docConcentration.head != -1) { + lda.setDocConcentration(docConcentration.head) + } else { + // Auto-set docConcentration + } + } else { + lda.setDocConcentration(docConcentration) + } + + val pipeline = new Pipeline().setStages(stages) + val model = pipeline.fit(data) + + val vocabulary: Array[String] = featureSchema.dataType match { + case d: StringType => + val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel] + countVectorModel.vocabulary + case _ => Array.empty[String] + } + + val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel] + val preprocessor: PipelineModel = + new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1)) + + val preprocessedData = preprocessor.transform(data) + + new LDAWrapper( + model, + ldaModel.logLikelihood(preprocessedData), + ldaModel.logPerplexity(preprocessedData), + vocabulary) + } + + override def read: MLReader[LDAWrapper] = new LDAWrapperReader + + override def load(path: String): LDAWrapper = super.load(path) + + class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("logLikelihood" -> instance.logLikelihood) ~ + ("logPerplexity" -> instance.logPerplexity) ~ + ("vocabulary" -> instance.vocabulary.toList) + val rMetadataJson: String = compact(render(rMetadata)) + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.pipeline.save(pipelinePath) + } + } + + class LDAWrapperReader extends MLReader[LDAWrapper] { + + override def load(path: String): LDAWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val logLikelihood = (rMetadata \ "logLikelihood").extract[Double] + val logPerplexity = (rMetadata \ "logPerplexity").extract[Double] + val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray + + val pipeline = PipelineModel.load(pipelinePath) + new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 88ac26bc5e3517727cd54814eb67d775be2a62f4..e23af51df5718e55dcde0d6472e5e6d43d7be73f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] { GeneralizedLinearRegressionWrapper.load(path) case "org.apache.spark.ml.r.KMeansWrapper" => KMeansWrapper.load(path) + case "org.apache.spark.ml.r.LDAWrapper" => + LDAWrapper.load(path) case "org.apache.spark.ml.r.IsotonicRegressionWrapper" => IsotonicRegressionWrapper.load(path) case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>