Commit 5b258b8b authored by Ilya Matiach, committed by Joseph K. Bradley

[SPARK-16473][MLLIB] Fix BisectingKMeans Algorithm failing in edge case where no children exist in updateAssignments

## What changes were proposed in this pull request?

Fix a bug in which BisectingKMeans fails when a split leaves a child cluster with no points: in `updateAssignments`, every child index of a divisible cluster is looked up in the map of new cluster centers, and a child for which no center was ever computed triggers:

    java.util.NoSuchElementException: key not found: 166
        at scala.collection.MapLike$class.default(MapLike.scala:228)
        at scala.collection.AbstractMap.default(Map.scala:58)
        at scala.collection.MapLike$class.apply(MapLike.scala:141)
        at scala.collection.AbstractMap.apply(Map.scala:58)
        at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply$mcDJ$sp(BisectingKMeans.scala:338)
        at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337)
        at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1$$anonfun$2.apply(BisectingKMeans.scala:337)
        at scala.collection.TraversableOnce$$anonfun$minBy$1.apply(TraversableOnce.scala:231)
        at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:111)
        at scala.collection.immutable.List.foldLeft(List.scala:84)
        at scala.collection.LinearSeqOptimized$class.reduceLeft(LinearSeqOptimized.scala:125)
        at scala.collection.immutable.List.reduceLeft(List.scala:84)
        at scala.collection.TraversableOnce$class.minBy(TraversableOnce.scala:231)
        at scala.collection.AbstractTraversable.minBy(Traversable.scala:105)
        at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:337)
        at org.apache.spark.mllib.clustering.BisectingKMeans$$anonfun$org$apache$spark$mllib$clustering$BisectingKMeans$$updateAssignments$1.apply(BisectingKMeans.scala:334)
        at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
        at scala.collection.Iterator$$anon$14.hasNext(Iterator.scala:389)
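
To make the failure mode concrete, here is a stripped-down, self-contained sketch of the pattern rather than Spark's actual internals: the point, the concrete index values, and the `squaredDistance` helper are illustrative stand-ins. BisectingKMeans indexes tree nodes so that a parent `i` has children `2i` and `2i + 1`, so the `key not found: 166` above corresponds to a child of cluster 83 that never received a center:

```scala
object MissingChildSketch {
  // Illustrative stand-in for KMeans.fastSquaredDistance.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  def main(args: Array[String]): Unit = {
    val point = Array(1.0, 0.0)

    // Parent cluster 83 was split into children 166 and 167, but suppose no
    // points landed in child 166, so no center was ever computed for it.
    val newClusterCenters: Map[Int, Array[Double]] = Map(167 -> Array(0.5, 0.5))
    val children = Seq(166, 167)

    // Before the fix: every child index is looked up inside minBy, so the
    // missing key throws java.util.NoSuchElementException: key not found: 166.
    // val selected = children.minBy(c => squaredDistance(newClusterCenters(c), point))

    // After the fix: only children that actually have a center are considered,
    // and the point stays in its current cluster when none of them do.
    val withCenters = children.filter(newClusterCenters.contains)
    val selected =
      if (withCenters.nonEmpty) {
        withCenters.minBy(c => squaredDistance(newClusterCenters(c), point))
      } else {
        83 // keep the point assigned to its current (parent) cluster index
      }

    println(selected) // prints 167
  }
}
```

The patch applies this guard in `updateAssignments` and an analogous one when assembling the clustering tree, so a node only gets children that actually exist (first and second hunks below).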

## How was this patch tested?

The failing dataset was run against the patched code to verify the fix. A unit test was also added to `BisectingKMeansSuite` that generates very sparse data and checks that `fit` succeeds while still hitting the edge case (fewer than `k` clusters are produced).
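
For reference, a standalone reproduction along the lines of the new test looks roughly like this. It is a sketch, not the exact data or entry point used: the local SparkSession setup, the dataset sizes, and the sparsity choice are illustrative, and whether a given seed actually produces an empty child after a split depends on the generated data.

```scala
import scala.util.Random

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object Spark16473Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-16473 repro")
      .getOrCreate()
    import spark.implicits._

    // A handful of very sparse, high-dimensional points makes it likely that a
    // bisecting split leaves one child cluster with no points at all.
    val random = new Random(42)
    val dim = 1000
    val data = (1 to 10).map { _ =>
      val nnz = 1 + random.nextInt(5)
      val indices = random.shuffle((0 until dim).toList).take(nnz).sorted.toArray
      Tuple1(Vectors.sparse(dim, indices, Array.fill(nnz)(random.nextDouble())))
    }.toDF("features")

    val model = new BisectingKMeans()
      .setK(5)
      .setMinDivisibleClusterSize(4)
      .setMaxIter(4)
      .setSeed(123)
      .fit(data) // the pre-fix code could throw NoSuchElementException here

    // With the fix, training succeeds and simply yields fewer than k clusters.
    println(model.transform(data).select("prediction").distinct().count())
    spark.stop()
  }
}
```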


Author: Ilya Matiach <ilmat@microsoft.com>

Closes #16355 from imatiach-msft/ilmat/fix-kmeans.
parent c8aea744

--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1

--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
+  @transient var sparseDataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)

--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.