diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index d75b3ef420211070763436be9f486fd839697583..18896fcc4d8c1fe6ae73144838f0be240fab21cf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -118,7 +118,7 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi
     require(totalSamples > 0,
       "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
     val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
-    val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0)
+    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
     dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
   }