Skip to content
Snippets Groups Projects
Commit 8d29001d authored by Xusen Yin's avatar Xusen Yin Committed by Xiangrui Meng
Browse files

[SPARK-13011] K-means wrapper in SparkR

https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11124 from yinxusen/SPARK-13011.
parent 15e30155
No related branches found
No related tags found
No related merge requests found
......@@ -13,7 +13,9 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
"predict",
"summary")
"summary",
"kmeans",
"fitted")
# Job group lifecycle management methods
export("setJobGroup",
......
......@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
#' @rdname rbind
#' @export
setGeneric("rbind", signature = "...")
#' @rdname kmeans
#' @export
setGeneric("kmeans")
#' @rdname fitted
#' @export
setGeneric("fitted")
......@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
"getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelFeatures", object@model)
"getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelCoefficients", object@model)
"getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", object@model)
......@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
} else {
} else if (modelName == "LogisticRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
} else if (modelName == "KMeansModel") {
modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansModelSize", object@model)
cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, "classes")
k <- unlist(modelSize)[1]
size <- unlist(modelSize)[-1]
coefficients <- t(matrix(coefficients, ncol = k))
colnames(coefficients) <- unlist(features)
rownames(coefficients) <- 1:k
return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
})
#' Fit a k-means model
#'
#' Fit a k-means model, similarly to R's kmeans().
#'
#' @param x DataFrame for training
#' @param centers Number of centers
#' @param iter.max Maximum iteration number
#' @param algorithm Algorithm choosen to fit the model
#' @return A fitted k-means model
#' @rdname kmeans
#' @export
#' @examples
#'\dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
#'}
setMethod("kmeans", signature(x = "DataFrame"),
function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
columnNames <- as.array(colnames(x))
algorithm <- match.arg(algorithm)
model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
algorithm, iter.max, centers, columnNames)
return(new("PipelineModel", model = model))
})
#' Get fitted result from a model
#'
#' Get fitted result from a model, similarly to R's fitted().
#'
#' @param object A fitted MLlib model
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
#' @examples
#'\dontrun{
#' model <- kmeans(trainingData, 2)
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
setMethod("fitted", signature(object = "PipelineModel"),
function(object, method = c("centers", "classes"), ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelName", object@model)
if (modelName == "KMeansModel") {
method <- match.arg(method)
fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getKMeansCluster", object@model, method)
return(dataFrame(fittedResult))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
})
......@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
test_that("kmeans", {
newIris <- iris
newIris$Species <- NULL
training <- suppressWarnings(createDataFrame(sqlContext, newIris))
# Cache the DataFrame here to work around the bug SPARK-13178.
cache(training)
take(training, 1)
model <- kmeans(x = training, centers = 2)
sample <- take(select(predict(model, training), "prediction"), 1)
expect_equal(typeof(sample$prediction), "integer")
expect_equal(sample$prediction, 1)
# Test stats::kmeans is working
statsModel <- kmeans(x = newIris, centers = 2)
expect_equal(unique(statsModel$cluster), c(1, 2))
# Test fitted works on KMeans
fitted.model <- fitted(model)
expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
# Test summary works on KMeans
summary.model <- summary(model)
cluster <- summary.model$cluster
expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
})
......@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
......@@ -135,6 +136,26 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
private var trainingSummary: Option[KMeansSummary] = None
private[clustering] def setSummary(summary: KMeansSummary): this.type = {
this.trainingSummary = Some(summary)
this
}
/**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@Since("2.0.0")
def summary: KMeansSummary = trainingSummary match {
case Some(summ) => summ
case None =>
throw new SparkException(
s"No training summary available for the ${this.getClass.getSimpleName}",
new NullPointerException())
}
}
@Since("1.6.0")
......@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
.setSeed($(seed))
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = new KMeansModel(uid, parentModel)
copyValues(model.setParent(this))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
model.setSummary(summary)
}
@Since("1.5.0")
......@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val featuresCol: String) extends Serializable {
/**
* Cluster centers of the transformed data.
*/
@Since("2.0.0")
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
/**
* Size of each cluster.
*/
@Since("2.0.0")
lazy val size: Array[Int] = cluster.map {
case Row(clusterIdx: Int) => (clusterIdx, 1)
}.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
}
......@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame
......@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}
def fitKMeans(
df: DataFrame,
initMode: String,
maxIter: Double,
k: Double,
columns: Array[String]): PipelineModel = {
val assembler = new VectorAssembler().setInputCols(columns)
val kMeans = new KMeans()
.setInitMode(initMode)
.setMaxIter(maxIter.toInt)
.setK(k.toInt)
.setFeaturesCol(assembler.getOutputCol)
val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
pipeline.fit(df)
}
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
case m: LinearRegressionModel => {
......@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
m.coefficients.toArray
}
}
case m: KMeansModel =>
m.clusterCenters.flatMap(_.toArray)
}
}
......@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
}
}
def getKMeansModelSize(model: PipelineModel): Array[Int] = {
model.stages.last match {
case m: KMeansModel => Array(m.getK) ++ m.summary.size
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}
def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
model.stages.last match {
case m: KMeansModel =>
if (method == "centers") {
// Drop the assembled vector for easy-print to R side.
m.summary.predictions.drop(m.summary.featuresCol)
} else if (method == "classes") {
m.summary.cluster
} else {
throw new UnsupportedOperationException(
s"Method (centers or classes) required but $method found.")
}
case other => throw new UnsupportedOperationException(
s"KMeansModel required but ${other.getClass.getSimpleName} found.")
}
}
def getModelFeatures(model: PipelineModel): Array[String] = {
model.stages.last match {
case m: LinearRegressionModel =>
......@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
} else {
attrs.attributes.get.map(_.name.get)
}
case m: KMeansModel =>
val attrs = AttributeGroup.fromStructField(
m.summary.predictions.schema(m.summary.featuresCol))
attrs.attributes.get.map(_.name.get)
}
}
......@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
case m: KMeansModel =>
"KMeansModel"
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment