diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
index 336f2fc114309b99834ee3cc8c40fb8af9208bdd..ae98e24a75681df08a5f3cbe2d9fd17fa09106cd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -339,10 +339,15 @@ private object BisectingKMeans extends Serializable {
     assignments.map { case (index, v) =>
       if (divisibleIndices.contains(index)) {
         val children = Seq(leftChildIndex(index), rightChildIndex(index))
-        val selected = children.minBy { child =>
-          KMeans.fastSquaredDistance(newClusterCenters(child), v)
+        val newClusterChildren = children.filter(newClusterCenters.contains(_))
+        if (newClusterChildren.nonEmpty) {
+          val selected = newClusterChildren.minBy { child =>
+            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+          }
+          (selected, v)
+        } else {
+          (index, v)
         }
-        (selected, v)
       } else {
         (index, v)
       }
@@ -372,12 +377,12 @@ private object BisectingKMeans extends Serializable {
         internalIndex -= 1
         val leftIndex = leftChildIndex(rawIndex)
         val rightIndex = rightChildIndex(rawIndex)
-        val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+        val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
+        val height = math.sqrt(indexes.map { childIndex =>
           KMeans.fastSquaredDistance(center, clusters(childIndex).center)
         }.max)
-        val left = buildSubTree(leftIndex)
-        val right = buildSubTree(rightIndex)
-        new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+        val children = indexes.map(buildSubTree(_)).toArray
+        new ClusteringTreeNode(index, size, center, cost, height, children)
       } else {
         val index = leafIndex
         leafIndex += 1
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index fc491cd6161fda69a1c30ee261f76b3a67f30df1..30513c1e276aec28c3ce4b9db4d20a369192539d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -29,9 +29,12 @@ class BisectingKMeansSuite
   final val k = 5
   @transient var dataset: Dataset[_] = _
 
+  @transient var sparseDataset: Dataset[_] = _
+
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k)
+    sparseDataset = KMeansSuite.generateSparseData(spark, 10, 1000, 42)
   }
 
   test("default parameters") {
@@ -51,6 +54,22 @@ class BisectingKMeansSuite
     assert(copiedModel.hasSummary)
   }
 
+  test("SPARK-16473: Verify Bisecting K-Means does not fail in edge case where" +
+    "one cluster is empty after split") {
+    val bkm = new BisectingKMeans()
+      .setK(k)
+      .setMinDivisibleClusterSize(4)
+      .setMaxIter(4)
+      .setSeed(123)
+
+    // Verify fit does not fail on very sparse data
+    val model = bkm.fit(sparseDataset)
+    val result = model.transform(sparseDataset)
+    val numClusters = result.select("prediction").distinct().collect().length
+    // Verify we hit the edge case
+    assert(numClusters < k && numClusters > 1)
+  }
+
   test("setter/getter") {
     val bkm = new BisectingKMeans()
       .setK(9)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c1b7242e11a8ff4099aff7130fcc106f12b8869e..e10127f7d108f3e9c31a02981c554b8d1fab6aa4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.clustering
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
@@ -160,6 +162,17 @@ object KMeansSuite {
     spark.createDataFrame(rdd)
   }
 
+  def generateSparseData(spark: SparkSession, rows: Int, dim: Int, seed: Int): DataFrame = {
+    val sc = spark.sparkContext
+    val random = new Random(seed)
+    val nnz = random.nextInt(dim)
+    val rdd = sc.parallelize(1 to rows)
+      .map(i => Vectors.sparse(dim, random.shuffle(0 to dim - 1).slice(0, nnz).sorted.toArray,
+        Array.fill(nnz)(random.nextDouble())))
+      .map(v => new TestRow(v))
+    spark.createDataFrame(rdd)
+  }
+
   /**
    * Mapping from all Params to valid settings which differ from the defaults.
    * This is useful for tests which need to exercise all Params, such as save/load.