diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 4404cffc292aaedfe372caf8cec379652231c639..e1b87b28d35ae51d7161480a90f187e9eff45c1a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -29,7 +29,8 @@ exportMethods("glm", "spark.posterior", "spark.perplexity", "spark.isoreg", - "spark.gaussianMixture") + "spark.gaussianMixture", + "spark.als") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index fe04bcfc7d14df62414478eca59c9ff8397d8ee1..693aa31d3ecab4c23b202502f2bd39841ffc9d54 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1332,3 +1332,7 @@ setGeneric("spark.gaussianMixture", #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) + +#' @rdname spark.als +#' @export +setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b9527410a98531e436b27129b454d7ca90331337..36f38fc73a5109bf5f39accd7594edeb567d0cbd 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -74,6 +74,13 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @note GaussianMixtureModel since 2.1.0 setClass("GaussianMixtureModel", representation(jobj = "jobj")) +#' S4 class that represents an ALSModel +#' +#' @param jobj a Java object reference to the backing Scala ALSWrapper +#' @export +#' @note ALSModel since 2.1.0 +setClass("ALSModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -82,8 +89,8 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture} -#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda} -#' @seealso \link{spark.isoreg} +#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes} +#' @seealso \link{spark.survreg}, \link{spark.isoreg} #' @seealso \link{read.ml} NULL @@ -95,10 +102,11 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture} -#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} +#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg} #' @seealso \link{spark.isoreg} NULL + #' Generalized Linear Models #' #' Fits generalized linear model against a Spark DataFrame. @@ -801,6 +809,8 @@ read.ml <- function(path) { return(new("IsotonicRegressionModel", jobj = jobj)) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { return(new("GaussianMixtureModel", jobj = jobj)) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { + return(new("ALSModel", jobj = jobj)) } else { stop(paste("Unsupported model: ", jobj)) } @@ -1053,4 +1063,145 @@ setMethod("summary", signature(object = "GaussianMixtureModel"), setMethod("predict", signature(object = "GaussianMixtureModel"), function(object, newData) { return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) - }) \ No newline at end of file + }) + +#' Alternating Least Squares (ALS) for Collaborative Filtering +#' +#' \code{spark.als} learns latent factors in collaborative filtering via alternating least +#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} +#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: +#' Collaborative Filtering}. +#' +#' @param data a SparkDataFrame for training. +#' @param ratingCol column name for ratings. +#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. +#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. +#' @param rank rank of the matrix factorization (> 0). +#' @param reg regularization parameter (>= 0). +#' @param maxIter maximum number of iterations (>= 0). +#' @param nonnegative logical value indicating whether to apply nonnegativity constraints. +#' @param implicitPrefs logical value indicating whether to use implicit preference. +#' @param alpha alpha parameter in the implicit preference formulation (>= 0). +#' @param seed integer seed for random number generation. +#' @param numUserBlocks number of user blocks used to parallelize computation (> 0). +#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). +#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' +#' @return \code{spark.als} returns a fitted ALS model +#' @rdname spark.als +#' @aliases spark.als,SparkDataFrame-method +#' @name spark.als +#' @export +#' @examples +#' \dontrun{ +#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), +#' list(2, 1, 1.0), list(2, 2, 5.0)) +#' df <- createDataFrame(ratings, c("user", "item", "rating")) +#' model <- spark.als(df, "rating", "user", "item") +#' +#' # extract latent factors +#' stats <- summary(model) +#' userFactors <- stats$userFactors +#' itemFactors <- stats$itemFactors +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # set other arguments +#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, +#' reg = 0.1, nonnegative = TRUE) +#' statsS <- summary(modelS) +#' } +#' @note spark.als since 2.1.0 +setMethod("spark.als", signature(data = "SparkDataFrame"), + function(data, ratingCol = "rating", userCol = "user", itemCol = "item", + rank = 10, reg = 1.0, maxIter = 10, nonnegative = FALSE, + implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, + checkpointInterval = 10, seed = 0) { + + if (!is.numeric(rank) || rank <= 0) { + stop("rank should be a positive number.") + } + if (!is.numeric(reg) || reg < 0) { + stop("reg should be a nonnegative number.") + } + if (!is.numeric(maxIter) || maxIter <= 0) { + stop("maxIter should be a positive number.") + } + + jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", + "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), + reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + as.integer(numUserBlocks), as.integer(numItemBlocks), + as.integer(checkpointInterval), as.integer(seed)) + return(new("ALSModel", jobj = jobj)) + }) + +# Returns a summary of the ALS model produced by spark.als. + +#' @param object a fitted ALS model. +#' @return \code{summary} returns a list containing the names of the user column, +#' the item column and the rating column, the estimated user and item factors, +#' rank, regularization parameter and maximum number of iterations used in training. +#' @rdname spark.als +#' @aliases summary,ALSModel-method +#' @export +#' @note summary(ALSModel) since 2.1.0 +setMethod("summary", signature(object = "ALSModel"), +function(object, ...) { + jobj <- object@jobj + user <- callJMethod(jobj, "userCol") + item <- callJMethod(jobj, "itemCol") + rating <- callJMethod(jobj, "ratingCol") + userFactors <- dataFrame(callJMethod(jobj, "userFactors")) + itemFactors <- dataFrame(callJMethod(jobj, "itemFactors")) + rank <- callJMethod(jobj, "rank") + return(list(user = user, item = item, rating = rating, userFactors = userFactors, + itemFactors = itemFactors, rank = rank)) +}) + + +# Makes predictions from an ALS model or a model produced by spark.als. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.als +#' @aliases predict,ALSModel-method +#' @export +#' @note predict(ALSModel) since 2.1.0 +setMethod("predict", signature(object = "ALSModel"), +function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) +}) + + +# Saves the ALS model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' +#' @rdname spark.als +#' @aliases write.ml,ALSModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(ALSModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "ALSModel", path = "character"), +function(object, path, overwrite = FALSE) { + writer <- callJMethod(object@jobj, "write") + if (overwrite) { + writer <- callJMethod(writer, "overwrite") + } + invisible(callJMethod(writer, "save", path)) +}) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index dfb7a185cd5a392f458ecc220378e1d21a8fc4de..67a3099101cf101b29f58224ab7fc7417794dfb9 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -657,4 +657,44 @@ test_that("spark.posterior and spark.perplexity", { expect_equal(length(local.posterior), sum(unlist(local.posterior))) }) +test_that("spark.als", { + data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), + list(2, 1, 1.0), list(2, 2, 5.0)) + df <- createDataFrame(data, c("user", "item", "score")) + model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", + rank = 10, maxIter = 5, seed = 0, reg = 0.1) + stats <- summary(model) + expect_equal(stats$rank, 10) + test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) + predictions <- collect(predict(model, test)) + + expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409), + tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) + + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + + unlink(modelPath) +}) + sparkR.session.stop() diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala new file mode 100644 index 0000000000000000000000000000000000000000..ad13cced4667bbad7302524be4c954c1aa283929 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/ALSWrapper.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.recommendation.{ALS, ALSModel} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class ALSWrapper private ( + val alsModel: ALSModel, + val ratingCol: String) extends MLWritable { + + lazy val userCol: String = alsModel.getUserCol + lazy val itemCol: String = alsModel.getItemCol + lazy val userFactors: DataFrame = alsModel.userFactors + lazy val itemFactors: DataFrame = alsModel.itemFactors + lazy val rank: Int = alsModel.rank + + def transform(dataset: Dataset[_]): DataFrame = { + alsModel.transform(dataset) + } + + override def write: MLWriter = new ALSWrapper.ALSWrapperWriter(this) +} + +private[r] object ALSWrapper extends MLReadable[ALSWrapper] { + + def fit( // scalastyle:ignore + data: DataFrame, + ratingCol: String, + userCol: String, + itemCol: String, + rank: Int, + regParam: Double, + maxIter: Int, + implicitPrefs: Boolean, + alpha: Double, + nonnegative: Boolean, + numUserBlocks: Int, + numItemBlocks: Int, + checkpointInterval: Int, + seed: Int): ALSWrapper = { + + val als = new ALS() + .setRatingCol(ratingCol) + .setUserCol(userCol) + .setItemCol(itemCol) + .setRank(rank) + .setRegParam(regParam) + .setMaxIter(maxIter) + .setImplicitPrefs(implicitPrefs) + .setAlpha(alpha) + .setNonnegative(nonnegative) + .setNumBlocks(numUserBlocks) + .setNumItemBlocks(numItemBlocks) + .setCheckpointInterval(checkpointInterval) + .setSeed(seed.toLong) + + val alsModel: ALSModel = als.fit(data) + + new ALSWrapper(alsModel, ratingCol) + } + + override def read: MLReader[ALSWrapper] = new ALSWrapperReader + + override def load(path: String): ALSWrapper = super.load(path) + + class ALSWrapperWriter(instance: ALSWrapper) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val modelPath = new Path(path, "model").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("ratingCol" -> instance.ratingCol) + val rMetadataJson: String = compact(render(rMetadata)) + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + + instance.alsModel.save(modelPath) + } + } + + class ALSWrapperReader extends MLReader[ALSWrapper] { + + override def load(path: String): ALSWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val modelPath = new Path(path, "model").toString + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val ratingCol = (rMetadata \ "ratingCol").extract[String] + val alsModel = ALSModel.load(modelPath) + + new ALSWrapper(alsModel, ratingCol) + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index e23af51df5718e55dcde0d6472e5e6d43d7be73f..51a65f7fc4fe86a55488656cf178ec70fccbc97a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -50,6 +50,8 @@ private[r] object RWrappers extends MLReader[Object] { IsotonicRegressionWrapper.load(path) case "org.apache.spark.ml.r.GaussianMixtureWrapper" => GaussianMixtureWrapper.load(path) + case "org.apache.spark.ml.r.ALSWrapper" => + ALSWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") }