diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index eed829356f2be468c1ccef78960334316b65187f..074e9cbebe1d467a2f4679cb4acf494dfa2dc15d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -733,8 +733,6 @@ setMethod("predict", signature(object = "KMeansModel"), #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p #' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. #' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model @@ -746,45 +744,35 @@ setMethod("predict", signature(object = "KMeansModel"), #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' label <- c(0.0, 0.0, 0.0, 1.0, 1.0) -#' features <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) -#' binary_data <- as.data.frame(cbind(label, features)) -#' binary_df <- createDataFrame(binary_data) -#' blr_model <- spark.logit(binary_df, label ~ features, thresholds = 1.0) -#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) -#' -#' # summary of binary logistic regression -#' blr_summary <- summary(blr_model) -#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) +#' df <- createDataFrame(iris) +#' training <- df[df$Species %in% c("versicolor", "virginica"), ] +#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' summary <- summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, training) +#' #' # save fitted model to input path #' path <- "path/to/model" -#' write.ml(blr_model, path) +#' write.ml(model, path) #' #' # can also read 
back the saved model and predict #' # Note that summary deos not work on loaded model #' savedModel <- read.ml(path) -#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) +#' summary(savedModel) #' #' # multinomial logistic regression #' -#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) -#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) -#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) -#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) -#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) -#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) -#' df <- createDataFrame(data) +#' df <- createDataFrame(iris) +#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' summary <- summary(model) #' -#' # Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) -#' predict1 <- collect(select(predict(model, df), "prediction")) #' } #' @note spark.logit since 2.1.0 setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, - probabilityCol = "probability") { + thresholds = 0.5, weightCol = NULL, probabilityCol = "probability") { formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { @@ -796,8 +784,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - as.character(weightCol), as.integer(aggregationDepth), - as.character(probabilityCol)) + as.character(weightCol), as.character(probabilityCol)) new("LogisticRegressionModel", jobj = jobj) }) @@ -817,10 
+804,7 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel #' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, -#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, -#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic -#' regression summary is not available now. +#' @return \code{summary} returns a list containing the coefficients matrix of the fitted model #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -828,33 +812,21 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), setMethod("summary", signature(object = "LogisticRegressionModel"), function(object) { jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - - if (is.loaded) { - stop("Loaded model doesn't have training summary.") + features <- callJMethod(jobj, "rFeatures") + labels <- callJMethod(jobj, "labels") + coefficients <- callJMethod(jobj, "rCoefficients") + nCol <- length(coefficients) / length(features) + coefficients <- matrix(coefficients, ncol = nCol) + # If nCol == 1, this is a binomial logistic regression model with pivoting. + # Otherwise, it's a multinomial logistic regression model without pivoting. 
+ if (nCol == 1) { + colnames(coefficients) <- c("Estimate") + } else { + colnames(coefficients) <- unlist(labels) } + rownames(coefficients) <- unlist(features) - roc <- dataFrame(callJMethod(jobj, "roc")) - - areaUnderROC <- callJMethod(jobj, "areaUnderROC") - - pr <- dataFrame(callJMethod(jobj, "pr")) - - fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) - - precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) - - recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) - - totalIterations <- callJMethod(jobj, "totalIterations") - - objectiveHistory <- callJMethod(jobj, "objectiveHistory") - - list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, - fMeasureByThreshold = fMeasureByThreshold, - precisionByThreshold = precisionByThreshold, - recallByThreshold = recallByThreshold, - totalIterations = totalIterations, objectiveHistory = objectiveHistory) + list(coefficients = coefficients) }) #' Multilayer Perceptron Classification Model diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index dcfeeb4cd2aa160a1d26ede9b8a589d33febb7c9..0802a2ae48e47e79473a9fd7dea84b516573b48b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -635,68 +635,141 @@ test_that("spark.isotonicRegression", { }) test_that("spark.logit", { - # test binary logistic regression - label <- c(0.0, 0.0, 0.0, 1.0, 1.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - binary_data <- as.data.frame(cbind(label, feature)) - binary_df <- createDataFrame(binary_data) - - blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) - blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) - expect_equal(blr_predict$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0")) - blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) - blr_predict1 <- collect(select(predict(blr_model1, binary_df), 
"prediction")) - expect_equal(blr_predict1$prediction, c("1.0", "1.0", "1.0", "1.0", "1.0")) - - # test summary of binary logistic regression - blr_summary <- summary(blr_model) - blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) - expect_equal(blr_fmeasure$threshold, c(0.6565513, 0.6214563, 0.3325291, 0.2115995, 0.1778653), - tolerance = 1e-4) - expect_equal(blr_fmeasure$"F-Measure", c(0.6666667, 0.5000000, 0.8000000, 0.6666667, 0.5714286), - tolerance = 1e-4) - blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) - expect_equal(blr_precision$precision, c(1.0000000, 0.5000000, 0.6666667, 0.5000000, 0.4000000), - tolerance = 1e-4) - blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) - expect_equal(blr_recall$recall, c(0.5000000, 0.5000000, 1.0000000, 1.0000000, 1.0000000), - tolerance = 1e-4) + # R code to reproduce the result. + # nolint start + #' library(glmnet) + #' iris.x = as.matrix(iris[, 1:4]) + #' iris.y = as.factor(as.character(iris[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $setosa + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.0981324 + # Sepal.Length -0.2909860 + # Sepal.Width 0.5510907 + # Petal.Length -0.1915217 + # Petal.Width -0.4211946 + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.520061e+00 + # Sepal.Length 2.524501e-02 + # Sepal.Width -5.310313e-01 + # Petal.Length 3.656543e-02 + # Petal.Width -3.144464e-05 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -2.61819385 + # Sepal.Length 0.26574097 + # Sepal.Width -0.02005932 + # Petal.Length 0.15495629 + # Petal.Width 0.42122607 + # nolint end - # test model save and read - modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") - write.ml(blr_model, modelPath) - expect_error(write.ml(blr_model, modelPath)) - write.ml(blr_model, 
modelPath, overwrite = TRUE) - blr_model2 <- read.ml(modelPath) - blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) - expect_equal(blr_predict$prediction, blr_predict2$prediction) - expect_error(summary(blr_model2)) + # Test multinomial logistic regression against three classes + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.logit(df, Species ~ ., regParam = 0.5) + summary <- summary(model) + versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) + virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) + setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + setosaCoefs <- unlist(summary$coefficients[, "setosa"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + expect_true(all(abs(setosaCoefsR - setosaCoefs) < 0.1)) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) unlink(modelPath) - # test prediction label as text - training <- suppressWarnings(createDataFrame(iris)) - binomial_training <- training[training$Species %in% c("versicolor", "virginica"), ] - binomial_model <- spark.logit(binomial_training, Species ~ Sepal_Length + Sepal_Width) - prediction <- predict(binomial_model, binomial_training) + # R code to reproduce the result. 
+ # nolint start + #' library(glmnet) + #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] + #' iris.x = as.matrix(iris2[, 1:4]) + #' iris.y = as.factor(as.character(iris2[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 3.93844796 + # Sepal.Length -0.13538675 + # Sepal.Width -0.02386443 + # Petal.Length -0.35076451 + # Petal.Width -0.77971954 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -3.93844796 + # Sepal.Length 0.13538675 + # Sepal.Width 0.02386443 + # Petal.Length 0.35076451 + # Petal.Width 0.77971954 + # + #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # (Intercept) -6.0824412 + # Sepal.Length 0.2458260 + # Sepal.Width 0.1642093 + # Petal.Length 0.4759487 + # Petal.Width 1.0383948 + # + # nolint end + + # Test multinomial logistic regression against two classes + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") + summary <- summary(model) + versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) + virginicaCoefsR <- c(-3.94, 0.16, 0.02, 0.35, 0.78) + versicolorCoefs <- unlist(summary$coefficients[, "versicolor"]) + virginicaCoefs <- unlist(summary$coefficients[, "virginica"]) + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test binomial logistic regression against two classes + model <- spark.logit(training, Species ~ ., regParam = 0.5) + summary <- summary(model) + coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) + coefs <- unlist(summary$coefficients[, "Estimate"]) + expect_true(all(abs(coefsR - coefs) < 0.1)) + + # Test prediction with string label + 
prediction <- predict(model, training) expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") - expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", - "versicolor", "virginica", "versicolor", "virginica", "versicolor") + expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) - # test multinomial logistic regression - label <- c(0.0, 1.0, 2.0, 0.0, 0.0) - feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) - feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) - feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) - feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) - data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) df <- createDataFrame(data) - - model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) - predict1 <- collect(select(predict(model, df), "prediction")) - expect_equal(predict1$prediction, c("0.0", "0.0", "0.0", "0.0", "0.0")) - # Summary of multinomial logistic regression is not implemented yet - expect_error(summary(model)) + model <- spark.logit(df, label ~ feature) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) }) test_that("spark.gaussianMixture", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 9fe6202980fca16cbc69e5f4a8c50c940098872c..7f0f3cea2124a37b255c06824da268cbfcb514c6 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -23,8 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,38 +33,48 @@ import org.apache.spark.sql.{DataFrame, Dataset} private[r] class LogisticRegressionWrapper private ( val pipeline: PipelineModel, val features: Array[String], - val isLoaded: Boolean = false) extends MLWritable { + val labels: Array[String]) extends MLWritable { import LogisticRegressionWrapper._ - private val logisticRegressionModel: LogisticRegressionModel = + private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - lazy val totalIterations: Int = logisticRegressionModel.summary.totalIterations - - lazy val objectiveHistory: Array[Double] = logisticRegressionModel.summary.objectiveHistory - - lazy val blrSummary = - logisticRegressionModel.summary.asInstanceOf[BinaryLogisticRegressionSummary] - - lazy val roc: DataFrame = blrSummary.roc - - lazy val areaUnderROC: Double = blrSummary.areaUnderROC - - lazy val pr: DataFrame = blrSummary.pr - - lazy val fMeasureByThreshold: DataFrame = blrSummary.fMeasureByThreshold - - lazy val precisionByThreshold: DataFrame = blrSummary.precisionByThreshold + val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } - lazy val recallByThreshold: DataFrame = 
blrSummary.recallByThreshold + val rCoefficients: Array[Double] = { + val numRows = lrModel.coefficientMatrix.numRows + val numCols = lrModel.coefficientMatrix.numCols + val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols + val coefficients: Array[Double] = new Array[Double](numRows * numColsWithIntercept) + val coefficientVectors: Seq[Vector] = lrModel.coefficientMatrix.rowIter.toSeq + var i = 0 + if (lrModel.getFitIntercept) { + while (i < numRows) { + coefficients(i * numColsWithIntercept) = lrModel.interceptVector(i) + System.arraycopy(coefficientVectors(i).toArray, 0, + coefficients, i * numColsWithIntercept + 1, numCols) + i += 1 + } + } else { + while (i < numRows) { + System.arraycopy(coefficientVectors(i).toArray, 0, + coefficients, i * numColsWithIntercept, numCols) + i += 1 + } + } + coefficients + } def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) - .drop(logisticRegressionModel.getFeaturesCol) - .drop(logisticRegressionModel.getLabelCol) - + .drop(lrModel.getFeaturesCol) + .drop(lrModel.getLabelCol) } override def write: MLWriter = new LogisticRegressionWrapper.LogisticRegressionWrapperWriter(this) @@ -86,8 +97,7 @@ private[r] object LogisticRegressionWrapper standardization: Boolean, thresholds: Array[Double], weightCol: String, - aggregationDepth: Int, - probability: String + probabilityCol: String ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -102,7 +112,7 @@ private[r] object LogisticRegressionWrapper val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline - val logisticRegression = new LogisticRegression() + val lr = new LogisticRegression() .setRegParam(regParam) .setElasticNetParam(elasticNetParam) .setMaxIter(maxIter) @@ -111,16 +121,15 @@ private[r] object LogisticRegressionWrapper .setFamily(family) .setStandardization(standardization) .setWeightCol(weightCol) - 
.setAggregationDepth(aggregationDepth) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) - .setProbabilityCol(probability) + .setProbabilityCol(probabilityCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (thresholds.length > 1) { - logisticRegression.setThresholds(thresholds) + lr.setThresholds(thresholds) } else { - logisticRegression.setThreshold(thresholds(0)) + lr.setThreshold(thresholds(0)) } val idxToStr = new IndexToString() @@ -129,10 +138,10 @@ private[r] object LogisticRegressionWrapper .setLabels(labels) val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, logisticRegression, idxToStr)) + .setStages(Array(rFormulaModel, lr, idxToStr)) .fit(data) - new LogisticRegressionWrapper(pipeline, features) + new LogisticRegressionWrapper(pipeline, features, labels) } override def read: MLReader[LogisticRegressionWrapper] = new LogisticRegressionWrapperReader @@ -146,7 +155,8 @@ private[r] object LogisticRegressionWrapper val pipelinePath = new Path(path, "pipeline").toString val rMetadata = ("class" -> instance.getClass.getName) ~ - ("features" -> instance.features.toSeq) + ("features" -> instance.features.toSeq) ~ + ("labels" -> instance.labels.toSeq) val rMetadataJson: String = compact(render(rMetadata)) sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) @@ -164,9 +174,10 @@ private[r] object LogisticRegressionWrapper val rMetadataStr = sc.textFile(rMetadataPath, 1).first() val rMetadata = parse(rMetadataStr) val features = (rMetadata \ "features").extract[Array[String]] + val labels = (rMetadata \ "labels").extract[Array[String]] val pipeline = PipelineModel.load(pipelinePath) - new LogisticRegressionWrapper(pipeline, features, isLoaded = true) + new LogisticRegressionWrapper(pipeline, features, labels) } } -} \ No newline at end of file +}