Commit 22249afb authored by Yanbo Liang, committed by Xiangrui Meng

[SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans

## What changes were proposed in this pull request?
Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. This is purely a code refactor of the original ```KMeans``` wrapper.
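
For context, a minimal sketch of the user-facing flow this refactor preserves (names such as `sqlContext` and the use of R's built-in `iris` data are illustrative, not part of the patch; assumes an initialized SparkR session):

```r
# Hedged sketch: the public SparkR API is unchanged by this refactor.
df <- createDataFrame(sqlContext, iris[, 1:4])          # numeric columns only
model <- kmeans(df, centers = 3, iter.max = 20, algorithm = "random")
summary(model)               # list(coefficients, size, cluster)
showDF(predict(model, df))   # input columns plus a "prediction" column
```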

## How was this patch tested?
Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
parent 26867ebc
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @export
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
#' @title S4 class that represents a KMeansModel
#' @param jobj a Java object reference to the backing Scala KMeansModel
#' @export
setClass("KMeansModel", representation(jobj = "jobj"))
#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else if (modelName == "KMeansModel") {
modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansModelSize", object@model)
cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, "classes")
k <- unlist(modelSize)[1]
size <- unlist(modelSize)[-1]
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @examples
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
#' }
setMethod("kmeans", signature(x = "DataFrame"),
function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
columnNames <- as.array(colnames(x))
algorithm <- match.arg(algorithm)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
algorithm, iter.max, centers, columnNames)
return(new("PipelineModel", model = model))
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
centers, iter.max, algorithm, columnNames)
return(new("KMeansModel", jobj = jobj))
})
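
A usage note on the method above: `match.arg` resolves `algorithm` against `c("random", "k-means||")`, so omitting the argument defaults to "random" and unambiguous prefixes are accepted. A hedged illustration (`df` is a hypothetical DataFrame):

```r
# Both m1 and m2 resolve initMode to "k-means||" via match.arg() prefix matching.
m1 <- kmeans(df, centers = 2, algorithm = "k-means||")
m2 <- kmeans(df, centers = 2, algorithm = "k-")
m3 <- kmeans(df, centers = 2)   # algorithm defaults to "random"
```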
#' Get fitted result from a model
#' Get fitted result from a k-means model
#'
#' Get fitted result from a model, similarly to R's fitted().
#' Get fitted result from a k-means model, similarly to R's fitted().
#'
#' @param object A fitted MLlib model
#' @param object A fitted k-means model
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
setMethod("fitted", signature(object = "PipelineModel"),
setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
method <- match.arg(method)
return(dataFrame(callJMethod(object@jobj, "fitted", method)))
})
if (modelName == "KMeansModel") {
method <- match.arg(method)
fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, method)
return(dataFrame(fittedResult))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
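
With the wrapper, `fitted` no longer branches on a model-name string; it dispatches straight to `KMeansWrapper.fitted` on the JVM. Per the Scala code later in this diff, "centers" returns the training predictions minus the assembled features vector, while "classes" returns only the cluster column. A hedged sketch (`model` is a hypothetical fitted KMeansModel):

```r
showDF(fitted(model))                       # method defaults to "centers"
showDF(fitted(model, method = "classes"))   # one cluster index per row
```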
#' Get the summary of a k-means model
#'
#' Returns the summary of a k-means model produced by kmeans(),
#' similarly to R's summary().
#'
#' @param object a fitted k-means model
#' @return the model's coefficients, size and cluster
#' @rdname summary
#' @export
#' @examples
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' summary(model)
#' }
setMethod("summary", signature(object = "KMeansModel"),
function(object, ...) {
jobj <- object@jobj
features <- callJMethod(jobj, "features")
coefficients <- callJMethod(jobj, "coefficients")
cluster <- callJMethod(jobj, "cluster")
k <- callJMethod(jobj, "k")
size <- callJMethod(jobj, "size")
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
})
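
A hedged sketch of consuming the list returned above (`model` is hypothetical); note that `coefficients` is reshaped into a k-by-features matrix of cluster centers:

```r
s <- summary(model)
s$coefficients      # k x nFeatures matrix, one row per cluster center
s$size              # number of training points in each cluster
showDF(s$cluster)   # per-row cluster assignment as a DataFrame
```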
#' Make predictions from a k-means model
#'
#' Make predictions from a model produced by kmeans().
#'
#' @param object A fitted k-means model
#' @param newData DataFrame for testing
#' @return DataFrame containing predicted labels in a column named "prediction"
#' @rdname predict
#' @export
#' @examples
#' \dontrun{
#' model <- kmeans(trainingData, 2)
#' predicted <- predict(model, testData)
#' showDF(predicted)
#' }
setMethod("predict", signature(object = "KMeansModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
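
`predict` maps to `KMeansWrapper.transform`, which runs the whole fitted Pipeline (VectorAssembler, then KMeansModel) and drops the assembled features column. A hedged usage sketch (`model` and `testData` are hypothetical):

```r
predicted <- predict(model, testData)
showDF(select(predicted, "prediction"))   # cluster index for each test row
```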
#' Fit a Bernoulli naive Bayes model
New file: KMeansWrapper.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.DataFrame
private[r] class KMeansWrapper private (
pipeline: PipelineModel) {
private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
private lazy val attrs = AttributeGroup.fromStructField(
kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
lazy val k: Int = kMeansModel.getK
lazy val size: Array[Int] = kMeansModel.summary.size
lazy val cluster: DataFrame = kMeansModel.summary.cluster
def fitted(method: String): DataFrame = {
if (method == "centers") {
kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
} else if (method == "classes") {
kMeansModel.summary.cluster
} else {
throw new UnsupportedOperationException(
s"Method (centers or classes) required but $method found.")
}
}
def transform(dataset: DataFrame): DataFrame = {
pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
}
}
private[r] object KMeansWrapper {
def fit(
data: DataFrame,
k: Double,
maxIter: Double,
initMode: String,
columns: Array[String]): KMeansWrapper = {
val assembler = new VectorAssembler()
.setInputCols(columns)
.setOutputCol("features")
val kMeans = new KMeans()
.setK(k.toInt)
.setMaxIter(maxIter.toInt)
.setInitMode(initMode)
val pipeline = new Pipeline()
.setStages(Array(assembler, kMeans))
.fit(data)
new KMeansWrapper(pipeline)
}
}
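
For reference, `fit` above is the target of the R side's `callJStatic` call; `k` and `maxIter` are typed `Double` because SparkR passes R numerics across the bridge. A hedged round-trip sketch from R, mirroring the setMethod bodies earlier in this diff (`x` is a hypothetical DataFrame):

```r
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit",
                    x@sdf, 2, 10, "random", as.array(colnames(x)))
callJMethod(jobj, "k")      # lazy val k on the Scala side
callJMethod(jobj, "size")   # per-cluster sizes from the training summary
```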
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame
@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}
def fitKMeans(
df: DataFrame,
initMode: String,
maxIter: Double,
k: Double,
columns: Array[String]): PipelineModel = {
val assembler = new VectorAssembler().setInputCols(columns)
val kMeans = new KMeans()
.setInitMode(initMode)
.setMaxIter(maxIter.toInt)
.setK(k.toInt)
.setFeaturesCol(assembler.getOutputCol)
val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
pipeline.fit(df)
}
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
m.coefficients.toArray
}
}
case m: KMeansModel =>
m.clusterCenters.flatMap(_.toArray)
}
}
@@ -104,31 +85,6 @@ private[r] object SparkRWrappers {
}
}
def getKMeansModelSize(model: PipelineModel): Array[Int] = {
model.stages.last match {
case m: KMeansModel => Array(m.getK) ++ m.summary.size
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}
def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
model.stages.last match {
case m: KMeansModel =>
if (method == "centers") {
// Drop the assembled vector for easy-print to R side.
m.summary.predictions.drop(m.summary.featuresCol)
} else if (method == "classes") {
m.summary.cluster
} else {
throw new UnsupportedOperationException(
s"Method (centers or classes) required but $method found.")
}
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}
def getModelFeatures(model: PipelineModel): Array[String] = {
model.stages.last match {
case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@ private[r] object SparkRWrappers {
} else {
attrs.attributes.get.map(_.name.get)
}
case m: KMeansModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
attrs.attributes.get.map(_.name.get)
}
}
@@ -160,8 +112,6 @@ private[r] object SparkRWrappers {
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
case m: KMeansModel =>
"KMeansModel"
}
}
}