Commit a91aaf5a authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-14375][ML] Unit test for spark.ml KMeansSummary

## What changes were proposed in this pull request?
* Modify the ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters (see the sketch after this list).
* Add a unit test for spark.ml ```KMeansSummary```.
* Add ```Since``` tags.
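
A minimal plain-Scala sketch (hypothetical data, no Spark) of why the old ```clusterSizes``` was fragile: it derived counts only from clusters that actually occur in the predictions, so an empty cluster was silently dropped and the resulting array could be shorter than k. The fix pre-allocates k zeroed counters and fills in only the observed counts.

```scala
object ClusterSizesSketch {
  def main(args: Array[String]): Unit = {
    val k = 3
    val assignments = Seq(0, 0, 2, 2, 2) // cluster 1 is empty

    // Old behavior: only clusters present in the data are counted,
    // so the result has length 2 instead of k = 3.
    val oldSizes = assignments.groupBy(identity).toList.sortBy(_._1).map(_._2.size)
    println(oldSizes) // List(2, 3)

    // Fixed behavior: pre-fill k zeros, then overwrite observed counts.
    val sizes = Array.fill[Long](k)(0L)
    assignments.groupBy(identity).foreach { case (c, xs) => sizes(c) = xs.size.toLong }
    println(sizes.toList) // List(2, 0, 3)
  }
}
```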

## How was this patch tested?
Unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12254 from yanboliang/spark-14375.
parent 0d17593b
@@ -143,6 +143,12 @@ class KMeansModel private[ml] (
     this
   }
 
+  /**
+   * Return true if there exists a summary of the model.
+   */
+  @Since("2.0.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
   /**
    * Gets summary of model on training set. An exception is
    * thrown if `trainingSummary == None`.
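
A hypothetical caller-side sketch (not part of this diff) of the new guard: `summary` throws when no training summary is attached, so callers check `hasSummary` first. Assumes `model` is a `KMeansModel`.

```scala
// Guard access to the summary; `summary` throws if none is attached.
if (model.hasSummary) {
  println(s"cluster sizes: ${model.summary.clusterSizes.mkString(", ")}")
} else {
  println("no training summary attached (e.g., the model was loaded from disk)")
}
```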
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
-    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    val summary = new KMeansSummary(
+      model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(summary)
   }
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }
 
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
 class KMeansSummary private[clustering] (
     @Since("2.0.0") @transient val predictions: DataFrame,
     @Since("2.0.0") val predictionCol: String,
-    @Since("2.0.0") val featuresCol: String) extends Serializable {
+    @Since("2.0.0") val featuresCol: String,
+    @Since("2.0.0") val k: Int) extends Serializable {
 
   /**
    * Cluster centers of the transformed data.
@@ -296,11 +315,15 @@ class KMeansSummary private[clustering] (
   @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
 
   /**
-   * Size of each cluster.
+   * Size of (number of data points in) each cluster.
    */
   @Since("2.0.0")
-  lazy val clusterSizes: Array[Int] = cluster.rdd.map {
-    case Row(clusterIdx: Int) => (clusterIdx, 1)
-  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+  lazy val clusterSizes: Array[Long] = {
+    val sizes = Array.fill[Long](k)(0)
+    cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+      case Row(cluster: Int, count: Long) => sizes(cluster) = count
+    }
+    sizes
+  }
 }
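
An end-to-end sketch (not part of this diff) of the fixed ```clusterSizes``` contract, assuming a local ```SparkSession``` named ```spark``` and toy data: the returned array always has length k, with zeros for any empty clusters, and its entries sum to the number of rows.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Toy data; with k = 3 requested, some clusters may end up empty.
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(0.0, 0.1)),
  Tuple1(Vectors.dense(9.0, 9.0))
)).toDF("features")

val model = new KMeans().setK(3).setSeed(1L).fit(df)
val sizes: Array[Long] = model.summary.clusterSizes
assert(sizes.length == 3)       // always k entries after the fix
assert(sizes.sum == df.count()) // every row counted exactly once
```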
@@ -37,7 +37,7 @@ private[r] class KMeansWrapper private (
   lazy val k: Int = kMeansModel.getK
 
-  lazy val size: Array[Int] = kMeansModel.summary.clusterSizes
+  lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
 
   lazy val cluster: DataFrame = kMeansModel.summary.cluster
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     }
   }
 
-  test("fit & transform") {
+  test("fit, transform, and summary") {
     val predictionColName = "kmeans_prediction"
     val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
     val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
+
+    // Check validity of model summary
+    val numRows = dataset.count()
+    assert(model.hasSummary)
+    val summary: KMeansSummary = model.summary
+    assert(summary.predictionCol === predictionColName)
+    assert(summary.featuresCol === "features")
+    assert(summary.predictions.count() === numRows)
+    for (c <- Array(predictionColName, "features")) {
+      assert(summary.predictions.columns.contains(c))
+    }
+    assert(summary.cluster.columns === Array(predictionColName))
+    val clusterSizes = summary.clusterSizes
+    assert(clusterSizes.length === k)
+    assert(clusterSizes.sum === numRows)
+    assert(clusterSizes.forall(_ >= 0))
   }
 
   test("read/write") {