diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index fb8d9e75ad09126e4faf431679280bd4091551ce..fa40f9d0bf0d39c5a5fda9532cc05cd2c7e78935 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -98,7 +98,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = #' @param object a fitted gaussian mixture model. #' @return \code{summary} returns summary of the fitted model, which is a list. #' The list includes the model's \code{lambda} (lambda), \code{mu} (mu), -#' \code{sigma} (sigma), and \code{posterior} (posterior). +#' \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior). #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -112,6 +112,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), sigmaList <- callJMethod(jobj, "sigma") k <- callJMethod(jobj, "k") dim <- callJMethod(jobj, "dim") + loglik <- callJMethod(jobj, "logLikelihood") mu <- c() for (i in 1 : k) { start <- (i - 1) * dim + 1 @@ -129,7 +130,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), } else { dataFrame(callJMethod(jobj, "posterior")) } - list(lambda = lambda, mu = mu, sigma = sigma, + list(lambda = lambda, mu = mu, sigma = sigma, loglik = loglik, posterior = posterior, is.loaded = is.loaded) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index cfbdea5c041fbc825c3e636f8a9ac4fbd4ff8dfb..9de8362cde8f205bcc3ec88062365d28693ee42f 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -56,6 +56,10 @@ test_that("spark.gaussianMixture", { # [,1] [,2] # [1,] 0.2961543 0.160783 # [2,] 0.1607830 1.008878 + # + #' model$loglik + # + # [1] -46.89499 # nolint end data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808), list(0.3295078, -0.8204684), list(0.4874291, 0.7383247), @@ -72,9 +76,11 @@ test_that("spark.gaussianMixture", { rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081) rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874, 0.2961543, 0.160783, 0.1607830, 1.008878) + rLoglik <- -46.89499 expect_equal(stats$lambda, rLambda, tolerance = 1e-3) expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3) expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3) + expect_equal(unlist(stats$loglik), rLoglik, tolerance = 1e-3) p <- collect(select(predict(model, df), "prediction")) expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) @@ -88,6 +94,7 @@ test_that("spark.gaussianMixture", { expect_equal(stats$lambda, stats2$lambda) expect_equal(unlist(stats$mu), unlist(stats2$mu)) expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) unlink(modelPath) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala index b708702959829c9f749e13653d211a60d94722fa..9a98a8b18b1410bf45d0779eb08ec8296ca09a33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._ private[r] class GaussianMixtureWrapper private ( val pipeline: PipelineModel, val dim: Int, + val logLikelihood: Double, val isLoaded: Boolean = false) extends MLWritable { private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel] @@ -91,7 +92,10 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp .setStages(Array(rFormulaModel, gm)) .fit(data) - new GaussianMixtureWrapper(pipeline, dim) + val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel] + val logLikelihood: Double = gmm.summary.logLikelihood + + new GaussianMixtureWrapper(pipeline, dim, logLikelihood) } override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader @@ -105,7 +109,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp val pipelinePath = new Path(path, "pipeline").toString val rMetadata = ("class" -> instance.getClass.getName) ~ - ("dim" -> instance.dim) + ("dim" -> instance.dim) ~ + ("logLikelihood" -> instance.logLikelihood) val rMetadataJson: String = compact(render(rMetadata)) sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) @@ -124,7 +129,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp val rMetadataStr = sc.textFile(rMetadataPath, 1).first() val rMetadata = parse(rMetadataStr) val dim = (rMetadata \ "dim").extract[Int] - new GaussianMixtureWrapper(pipeline, dim, isLoaded = true) + val logLikelihood = (rMetadata \ "logLikelihood").extract[Double] + new GaussianMixtureWrapper(pipeline, dim, logLikelihood, isLoaded = true) } } }