diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index efe8b93d8235a579bb88b5bbf07a2848fcba14e5..5c7993af645af9b157d839e3eb5c43da8070d75d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
   with HasInputCol with HasOutputCol with HasSeed {
 
   /**
-   * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+   * Number of buckets (quantiles, or categories) into which data points are grouped. Must
    * be >= 2.
    * default: 2
    * @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
 
   /** @group getParam */
   def getNumBuckets: Int = getOrDefault(numBuckets)
+
+  /**
+   * Relative error (see documentation for
+   * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+   * Must be a number in [0, 1].
+   * default: 0.001
+   * @group param
+   */
+  val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+    "for approxQuantile",
+    ParamValidators.inRange(0.0, 1.0))
+  setDefault(relativeError -> 0.001)
+
+  /** @group getParam */
+  def getRelativeError: Double = getOrDefault(relativeError)
 }
 
 /**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
  * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
  */
 @Experimental
 final class QuantileDiscretizer(override val uid: String)
@@ -65,6 +79,9 @@ final class QuantileDiscretizer(override val uid: String)
 
   def this() = this(Identifiable.randomUID("quantileDiscretizer"))
 
+  /** @group setParam */
+  def setRelativeError(value: Double): this.type = set(relativeError, value)
+
   /** @group setParam */
   def setNumBuckets(value: Int): this.type = set(numBuckets, value)
 
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
-    val samples = QuantileDiscretizer
-      .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
-      .map { case Row(feature: Double) => feature }
-    val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
-    val splits = QuantileDiscretizer.getSplits(candidates)
+    val splits = dataset.stat.approxQuantile($(inputCol),
+      (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+    splits(0) = Double.NegativeInfinity
+    splits(splits.length - 1) = Double.PositiveInfinity
+
     val bucketizer = new Bucketizer(uid).setSplits(splits)
     copyValues(bucketizer.setParent(this))
   }
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
 @Since("1.6.0")
 object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
 
-  /**
-   * Minimum number of samples required for finding splits, regardless of number of bins.  If
-   * the dataset has fewer rows than this value, the entire dataset will be used.
-   */
-  private[spark] val minSamplesRequired: Int = 10000
-
-  /**
-   * Sampling from the given dataset to collect quantile statistics.
-   */
-  private[feature]
-  def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
-    val totalSamples = dataset.count()
-    require(totalSamples > 0,
-      "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
-    val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
-    val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
-    dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
-      .collect()
-  }
-
-  /**
-   * Compute split points with respect to the sample distribution.
-   */
-  private[feature]
-  def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
-    val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
-      m + ((x, m.getOrElse(x, 0) + 1))
-    }
-    val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
-    val possibleSplits = valueCounts.length - 1
-    if (possibleSplits <= numSplits) {
-      valueCounts.dropRight(1).map(_._1)
-    } else {
-      val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
-      val splitsBuilder = mutable.ArrayBuilder.make[Double]
-      var index = 1
-      // currentCount: sum of counts of values that have been visited
-      var currentCount = valueCounts(0)._2
-      // targetCount: target value for `currentCount`. If `currentCount` is closest value to
-      // `targetCount`, then current value is a split threshold. After finding a split threshold,
-      // `targetCount` is added by stride.
-      var targetCount = stride
-      while (index < valueCounts.length) {
-        val previousCount = currentCount
-        currentCount += valueCounts(index)._2
-        val previousGap = math.abs(previousCount - targetCount)
-        val currentGap = math.abs(currentCount - targetCount)
-        // If adding count of current value to currentCount makes the gap between currentCount and
-        // targetCount smaller, previous value is a split threshold.
-        if (previousGap < currentGap) {
-          splitsBuilder += valueCounts(index - 1)._1
-          targetCount += stride
-        }
-        index += 1
-      }
-      splitsBuilder.result()
-    }
-  }
-
-  /**
-   * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
-   * needed, and adding a default split value of 0 if no good candidates are found.
-   */
-  private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
-    val effectiveValues = if (candidates.nonEmpty) {
-      if (candidates.head == Double.NegativeInfinity
-        && candidates.last == Double.PositiveInfinity) {
-        candidates.drop(1).dropRight(1)
-      } else if (candidates.head == Double.NegativeInfinity) {
-        candidates.drop(1)
-      } else if (candidates.last == Double.PositiveInfinity) {
-        candidates.dropRight(1)
-      } else {
-        candidates
-      }
-    } else {
-      candidates
-    }
-
-    if (effectiveValues.isEmpty) {
-      Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
-    } else {
-      Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
-    }
-  }
-
   @Since("1.6.0")
   override def load(path: String): QuantileDiscretizer = super.load(path)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 25fabf64d5594ee57c969a83f371757e9743fc08..8895d630a08791881cd7f1dd70f58b82af44b398 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -17,78 +17,60 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.functions.udf
 
 class QuantileDiscretizerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
-
-  test("Test quantile discretizer") {
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      10,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      4,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity"))
-
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      3,
-      Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2),
-      Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity"))
+  test("Test observed number of buckets and their sizes match expected values") {
+    val sqlCtx = SQLContext.getOrCreate(sc)
+    import sqlCtx.implicits._
 
-    checkDiscretizedData(sc,
-      Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3),
-      2,
-      Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1),
-      Array("-Infinity, 2.0", "2.0, Infinity"))
+    val datasetSize = 100000
+    val numBuckets = 5
+    val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(numBuckets)
+    val result = discretizer.fit(df).transform(df)
 
-  }
+    val observedNumBuckets = result.select("result").distinct.count
+    assert(observedNumBuckets === numBuckets,
+      "Observed number of buckets does not equal expected number of buckets.")
 
-  test("Test getting splits") {
-    val splitTestPoints = Array(
-      Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(Double.NegativeInfinity, Double.PositiveInfinity)
-        -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity),
-      Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity),
-      Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity)
-    )
-    for ((ori, res) <- splitTestPoints) {
-      assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
+    val relativeError = discretizer.getRelativeError
+    val isGoodBucket = udf {
+      (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize)
     }
+    val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count
+    assert(numGoodBuckets === numBuckets,
+      "Bucket sizes are not within expected relative error tolerance.")
   }
 
-  test("Test splits on dataset larger than minSamplesRequired") {
+  test("Test transform method on unseen data") {
     val sqlCtx = SQLContext.getOrCreate(sc)
     import sqlCtx.implicits._
 
-    val datasetSize = QuantileDiscretizer.minSamplesRequired + 1
-    val numBuckets = 5
-    val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input")
+    val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input")
+    val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input")
     val discretizer = new QuantileDiscretizer()
       .setInputCol("input")
       .setOutputCol("result")
-      .setNumBuckets(numBuckets)
-      .setSeed(1)
+      .setNumBuckets(5)
 
-    val result = discretizer.fit(df).transform(df)
-    val observedNumBuckets = result.select("result").distinct.count
+    val result = discretizer.fit(trainDF).transform(testDF)
+    val firstBucketSize = result.filter(result("result") === 0.0).count
+    val lastBucketSize = result.filter(result("result") === 4.0).count
 
-    assert(observedNumBuckets === numBuckets,
-      "Observed number of buckets does not equal expected number of buckets.")
+    assert(firstBucketSize === 30L,
+      s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.")
+    assert(lastBucketSize === 31L,
+      s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.")
   }
 
   test("read/write") {
@@ -98,34 +80,17 @@ class QuantileDiscretizerSuite
       .setNumBuckets(6)
     testDefaultReadWrite(t)
   }
-}
-
-private object QuantileDiscretizerSuite extends SparkFunSuite {
 
-  def checkDiscretizedData(
-      sc: SparkContext,
-      data: Array[Double],
-      numBucket: Int,
-      expectedResult: Array[Double],
-      expectedAttrs: Array[String]): Unit = {
+  test("Verify resulting model has parent") {
     val sqlCtx = SQLContext.getOrCreate(sc)
     import sqlCtx.implicits._
 
-    val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
-    val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
-      .setNumBuckets(numBucket).setSeed(1)
+    val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input")
+    val discretizer = new QuantileDiscretizer()
+      .setInputCol("input")
+      .setOutputCol("result")
+      .setNumBuckets(5)
     val model = discretizer.fit(df)
     assert(model.hasParent)
-    val result = model.transform(df)
-
-    val transformedFeatures = result.select("result").collect()
-      .map { case Row(transformedFeature: Double) => transformedFeature }
-    val transformedAttrs = Attribute.fromStructField(result.schema("result"))
-      .asInstanceOf[NominalAttribute].values.get
-
-    assert(transformedFeatures === expectedResult,
-      "Transformed features do not equal expected features.")
-    assert(transformedAttrs === expectedAttrs,
-      "Transformed attributes do not equal expected attributes.")
   }
 }