diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 0f92b5e597c668ddd0c1b19aa8bc6389a802d10f..c0a63d6b3e72159dafc8026c7983e9f55df60190 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -107,7 +107,8 @@ exportMethods("arrange",
               "write.jdbc",
               "write.json",
               "write.parquet",
-              "write.text")
+              "write.text",
+              "ml.save")
 
 exportClasses("Column")
 
@@ -299,7 +300,8 @@ export("as.DataFrame",
        "tableNames",
        "tables",
        "uncacheTable",
-       "print.summary.GeneralizedLinearRegressionModel")
+       "print.summary.GeneralizedLinearRegressionModel",
+       "ml.load")
 
 export("structField",
        "structField.jobj",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 04274a12bcc1f049ec3dce752e404c1685341fc7..f654d8330c833559a03356039c0122dba2001086 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1200,3 +1200,7 @@ setGeneric("naiveBayes", function(formula, data, ...) { standardGeneric("naiveBa
 #' @rdname survreg
 #' @export
 setGeneric("survreg", function(formula, data, ...) { standardGeneric("survreg") })
+
+#' @rdname ml.save
+#' @export
+setGeneric("ml.save", function(object, path, ...) { standardGeneric("ml.save") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 7dd82963a1a69c519a5dda29312d1f5c3e5b04f5..cda6100e7989ec36b1bd6fbc8682ace78632089c 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -338,6 +338,54 @@ setMethod("naiveBayes", signature(formula = "formula", data = "SparkDataFrame"),
     return(new("NaiveBayesModel", jobj = jobj))
   })
 
+#' Save the Bernoulli naive Bayes model to the specified path.
+#'
+#' @param object A fitted Bernoulli naive Bayes model
+#' @param path The directory where the model is saved
+#' @param overwrite Whether to overwrite the output path if it already exists. Default is FALSE,
+#'        which means an error is thrown if the output path exists.
+#'
+#' @rdname ml.save
+#' @name ml.save
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame(sqlContext, infert)
+#' model <- naiveBayes(education ~ ., df, laplace = 0)
+#' path <- "path/to/model"
+#' ml.save(model, path)
+#' }
+setMethod("ml.save", signature(object = "NaiveBayesModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path Path of the model to read.
+#' @return A fitted MLlib model.
+#' @rdname ml.load
+#' @name ml.load
+#' @export
+#' @examples
+#' \dontrun{
+#' path <- "path/to/model"
+#' model <- ml.load(path)
+#' }
+ml.load <- function(path) {
+  path <- suppressWarnings(normalizePath(path))
+  jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
+  if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
+    return(new("NaiveBayesModel", jobj = jobj))
+  } else {
+    stop(paste("Unsupported model:", jobj))
+  }
+}
+
 #' Fit an accelerated failure time (AFT) survival regression model.
 #'
 #' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
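For orientation, here is a minimal Scala sketch of what the R `ml.save` method above drives on the JVM side through its `callJMethod` chain: obtain the wrapper's `MLWriter`, optionally switch it to overwrite mode, then save. `SaveSketch` and `saveWrapper` are hypothetical names for illustration, not part of Spark.

```scala
// Sketch of the JVM-side MLWriter chain that R's ml.save drives via callJMethod.
// SaveSketch and saveWrapper are hypothetical, not Spark code.
import org.apache.spark.ml.util.{MLWritable, MLWriter}

object SaveSketch {
  def saveWrapper(wrapper: MLWritable, path: String, overwrite: Boolean): Unit = {
    var writer: MLWriter = wrapper.write  // callJMethod(object@jobj, "write")
    if (overwrite) {
      writer = writer.overwrite()         // callJMethod(writer, "overwrite")
    }
    writer.save(path)                     // callJMethod(writer, "save", path);
                                          // throws if path exists and !overwrite
  }
}
```

This is why the test below expects an error on the second plain `ml.save` call and succeeds once `overwrite = TRUE` is passed.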
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 1597306bb67e5e6c3754c7d59e4742df6739861a..63ec84e4970a104ae86cb428e218b8cc2ff269da 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -204,6 +204,18 @@ test_that("naiveBayes", {
                                 "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes",
                                 "Yes", "No", "No", "Yes", "Yes", "No", "No"))
 
+  # Test model save/load
+  modelPath <- tempfile(pattern = "naiveBayes", fileext = ".tmp")
+  ml.save(m, modelPath)
+  expect_error(ml.save(m, modelPath))
+  ml.save(m, modelPath, overwrite = TRUE)
+  m2 <- ml.load(modelPath)
+  s2 <- summary(m2)
+  expect_equal(s$apriori, s2$apriori)
+  expect_equal(s$tables, s2$tables)
+
+  unlink(modelPath)
+
   # Test e1071::naiveBayes
   if (requireNamespace("e1071", quietly = TRUE)) {
     expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
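The Scala changes below persist the wrapper as two artifacts under the model directory: an `rMetadata` text file holding a one-line JSON document (class name, labels, features), and the fitted pipeline saved through `PipelineModel`. A self-contained sketch of that JSON round trip with json4s, assuming made-up label and feature values:

```scala
// Standalone sketch of the rMetadata JSON round trip used by the writer/reader
// below; the label and feature values are made up for illustration.
import org.json4s._
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object RMetadataSketch {
  implicit val format: Formats = DefaultFormats

  def main(args: Array[String]): Unit = {
    // Write side: render the metadata to a compact one-line JSON string.
    val rMetadata = ("class" -> "org.apache.spark.ml.r.NaiveBayesWrapper") ~
      ("labels" -> Seq("No", "Yes")) ~
      ("features" -> Seq("age", "education"))
    val json = compact(render(rMetadata))

    // Read side: parse the string back and extract the typed fields.
    val parsed = parse(json)
    val labels = (parsed \ "labels").extract[Array[String]]
    val features = (parsed \ "features").extract[Array[String]]
    println(s"labels=${labels.mkString(",")} features=${features.mkString(",")}")
  }
}
```

Writing the JSON with `sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(...)` and reading it back with `sc.textFile(...).first()` keeps the metadata on any Hadoop-compatible filesystem alongside the pipeline.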
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index b17207e99bb852ed71a0fcc96f5ea292601aab86..27c7e728810087172e8b11b519dac736d460e717 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -17,16 +17,23 @@
 
 package org.apache.spark.ml.r
 
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
 
 private[r] class NaiveBayesWrapper private (
-    pipeline: PipelineModel,
+    val pipeline: PipelineModel,
     val labels: Array[String],
-    val features: Array[String]) {
+    val features: Array[String]) extends MLWritable {
 
   import NaiveBayesWrapper._
 
@@ -41,9 +48,11 @@ private[r] class NaiveBayesWrapper private (
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(naiveBayesModel.getFeaturesCol)
   }
+
+  override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
 }
 
-private[r] object NaiveBayesWrapper {
+private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
 
   val PREDICTED_LABEL_INDEX_COL = "pred_label_idx"
   val PREDICTED_LABEL_COL = "prediction"
@@ -74,4 +83,41 @@ private[r] object NaiveBayesWrapper {
       .fit(data)
     new NaiveBayesWrapper(pipeline, labels, features)
   }
+
+  override def read: MLReader[NaiveBayesWrapper] = new NaiveBayesWrapperReader
+
+  override def load(path: String): NaiveBayesWrapper = super.load(path)
+
+  class NaiveBayesWrapperWriter(instance: NaiveBayesWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("labels" -> instance.labels.toSeq) ~
+        ("features" -> instance.features.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class NaiveBayesWrapperReader extends MLReader[NaiveBayesWrapper] {
+
+    override def load(path: String): NaiveBayesWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val labels = (rMetadata \ "labels").extract[Array[String]]
+      val features = (rMetadata \ "features").extract[Array[String]]
+
+      val pipeline = PipelineModel.load(pipelinePath)
+      new NaiveBayesWrapper(pipeline, labels, features)
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7f6f147532202a47c393698f457da9fe6eecbdb5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.util.MLReader
+
+/**
+ * This is the Scala stub of SparkR ml.load. It dispatches the call to the corresponding
+ * model wrapper's loading function, based on the class name stored in the path's rMetadata.
+ */
+private[r] object RWrappers extends MLReader[Object] {
+
+  override def load(path: String): Object = {
+    implicit val format = DefaultFormats
+    val rMetadataPath = new Path(path, "rMetadata").toString
+    val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+    val rMetadata = parse(rMetadataStr)
+    val className = (rMetadata \ "class").extract[String]
+    className match {
+      case "org.apache.spark.ml.r.NaiveBayesWrapper" => NaiveBayesWrapper.load(path)
+      case _ =>
+        throw new SparkException(s"SparkR ml.load does not support loading $className")
+    }
+  }
+}
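Supporting another model type in this dispatcher later amounts to one more `case` keyed on the persisted `"class"` field. A self-contained sketch of that dispatch shape, with the metadata parsing mirroring the code above and a stand-in loader body (`LoadDispatchSketch` is hypothetical, not Spark code):

```scala
// Self-contained sketch of RWrappers-style dispatch on the persisted "class"
// field; the loader result is a stand-in string, not Spark's actual wrapper.
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

object LoadDispatchSketch {
  implicit val format: DefaultFormats.type = DefaultFormats

  def load(rMetadataJson: String, path: String): String = {
    val className = (parse(rMetadataJson) \ "class").extract[String]
    className match {
      case "org.apache.spark.ml.r.NaiveBayesWrapper" =>
        s"would call NaiveBayesWrapper.load($path)"  // stand-in for the real loader
      case _ =>
        throw new IllegalArgumentException(
          s"SparkR ml.load does not support loading $className")
    }
  }

  def main(args: Array[String]): Unit = {
    val json = """{"class":"org.apache.spark.ml.r.NaiveBayesWrapper"}"""
    println(load(json, "/tmp/model"))  // illustrative path
  }
}
```

Returning `Object` from `RWrappers.load` lets the R side receive any wrapper type and then pick the S4 class to construct via `isInstanceOf`, as `ml.load` in mllib.R does above.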