diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 336f2fc114309b99834ee3cc8c40fb8af9208bdd..ae98e24a75681df08a5f3cbe2d9fd17fa09106cd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fc491cd6161fda69a1c30ee261f76b3a67f30df1..30513c1e276aec28c3ce4b9db4d20a369192539d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c1b7242e11a8ff4099aff7130fcc106f12b8869e..e10127f7d108f3e9c31a02981c554b8d1fab6aa4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.
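For reference, the failure mode this patch guards against can be exercised through the public spark.ml API alone. The sketch below mirrors the new test: a handful of very sparse, high-dimensional rows makes it likely that a bisecting split leaves one child cluster empty. It is a minimal illustration, not part of the patch; the `SparkSession` setup, object name, and the row count, dimensionality, and seeds are arbitrary choices for the example.

```scala
import scala.util.Random

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Hypothetical standalone repro; any existing SparkSession would do.
object BisectingKMeansSparseRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-16473 repro")
      .getOrCreate()

    val dim = 1000
    val random = new Random(42)
    val nnz = random.nextInt(dim)
    // 10 very sparse rows in a 1000-dimensional space, as in generateSparseData above.
    val rows = (1 to 10).map { _ =>
      val indices = random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray
      Tuple1(Vectors.sparse(dim, indices, Array.fill(nnz)(random.nextDouble())))
    }
    val data = spark.createDataFrame(rows).toDF("features")

    // Same parameters as the new test; with k = 5 and only 10 rows,
    // a split can easily produce an empty child cluster.
    val model = new BisectingKMeans()
      .setK(5)
      .setMinDivisibleClusterSize(4)
      .setMaxIter(4)
      .setSeed(123)
      .fit(data)

    // With the patch applied, fit() completes and the model may simply
    // end up with fewer than k clusters instead of throwing.
    println(model.clusterCenters.length)
    spark.stop()
  }
}
```

Before the change, an empty child cluster meant `newClusterCenters(child)` in `updateAssignments` and `clusters(childIndex)` in `buildSubTree` were looked up for an index that no longer existed, so `fit` could fail with a `NoSuchElementException`. After the change, points keep their parent assignment and the missing child is skipped, so the resulting tree simply has fewer than `k` leaves, which is what the new test asserts.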