diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 152b87351db310ac075d8ce709ab17887aa1743a..213338dfe58a48d062802744a41129c8dee772e9 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -448,6 +448,41 @@ class DataFrame(object):
         rdd = self._jdf.sample(withReplacement, fraction, long(seed))
         return DataFrame(rdd, self.sql_ctx)
 
+    @since(1.5)
+    def sampleBy(self, col, fractions, seed=None):
+        """
+        Returns a stratified sample without replacement based on the
+        fraction given for each stratum.
+
+        :param col: column that defines strata
+        :param fractions:
+            sampling fraction for each stratum. If a stratum is not
+            specified, we treat its fraction as zero.
+        :param seed: random seed
+        :return: a new :class:`DataFrame` that represents the stratified sample
+
+        >>> from pyspark.sql.functions import col
+        >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
+        >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
+        >>> sampled.groupBy("key").count().orderBy("key").show()
+        +---+-----+
+        |key|count|
+        +---+-----+
+        |  0|    5|
+        |  1|    8|
+        +---+-----+
+        """
+        if not isinstance(col, basestring):
+            raise ValueError("col must be a string, but got %r" % type(col))
+        if not isinstance(fractions, dict):
+            raise ValueError("fractions must be a dict, but got %r" % type(fractions))
+        for k in fractions:
+            if not isinstance(k, (float, int, long, basestring)):
+                raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
+        fractions = dict((k, float(v)) for k, v in fractions.items())  # don't mutate the caller's dict
+        seed = seed if seed is not None else random.randint(0, sys.maxsize)
+        return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
+
     @since(1.4)
     def randomSplit(self, weights, seed=None):
         """Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1322,6 +1357,11 @@ class DataFrameStatFunctions(object):
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
+    def sampleBy(self, col, fractions, seed=None):
+        return self.df.sampleBy(col, fractions, seed)
+
+    sampleBy.__doc__ = DataFrame.sampleBy.__doc__
+
 
 def _test():
     import doctest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index edb9ed7bba56ace02a48080b68f3d403e775bc7a..955d28771b4df1363216e9e90f865a0a04057fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -163,4 +163,28 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Returns a stratified sample without replacement based on the fraction given for each stratum.
+   * @param col column that defines strata
+   * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
+   *                  its fraction as zero.
+   * @param seed random seed
+   * @return a new [[DataFrame]] that represents the stratified sample
+   *
+   * @since 1.5.0
+   */
+  def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
+    require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
+      s"Fractions must be in [0, 1], but got $fractions.")
+    import org.apache.spark.sql.functions.{lit, rand}
+    val c = Column(col)
+    val r = rand(seed)
+    // Keep a row iff its stratum is listed in `fractions` and the random draw falls below
+    // that stratum's fraction; unlisted strata match no disjunct, i.e. fraction zero.
+    val expr = fractions.toSeq.map { case (k, v) =>
+      (c === k) && (r < v)
+    }.foldLeft(lit(false))(_ || _)
+    df.filter(expr)
+  }
 }
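
For reference, the predicate built above for Map(0 -> 0.1, 1 -> 0.2) is equivalent to
(key = 0 AND r < 0.1) OR (key = 1 AND r < 0.2): a stratum that does not appear in
`fractions` matches no disjunct and is dropped, which is exactly the documented
"fraction zero" behavior. A minimal usage sketch from the Scala side (not part of the
patch; it assumes a Spark 1.5-era SQLContext bound to `sqlContext`):

    import org.apache.spark.sql.functions.col

    // Three strata: key = 0, 1, 2, with 33-34 rows each.
    val df = sqlContext.range(0, 100).select((col("id") % 3).as("key"))

    // Keep ~10% of stratum 0 and ~20% of stratum 1; stratum 2 is unlisted,
    // so it is filtered out entirely.
    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 0L)
    sampled.groupBy("key").count().orderBy("key").show()

The sampled counts match fraction * stratum size only in expectation (roughly 3 and 7
here); the exact rows kept for a given seed also depend on partitioning, since rand(seed)
draws its values with a per-partition generator.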
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 0d3ff899dad72959a9b01f6d16c7fb41ae9ff512..3dd46889127ffdf2d191b385f64907126aef8834 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
 
 import org.scalatest.Matchers._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.functions.col
 
-class DataFrameStatSuite extends SparkFunSuite  {
+class DataFrameStatSuite extends QueryTest {
 
   private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
   import sqlCtx.implicits._
@@ -98,4 +98,12 @@ class DataFrameStatSuite extends SparkFunSuite  {
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
   }
+
+  test("sampleBy") {
+    val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+    val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
+    checkAnswer(
+      sampled.groupBy("key").count().orderBy("key"),
+      Seq(Row(0, 4), Row(1, 9)))
+  }
 }
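
The exact counts asserted above (4 for key 0, 9 for key 1) are tied to this seed and to
TestSQLContext's partitioning; the Python doctest shows different counts for the same
data and seed, most likely because the two test environments partition the input
differently. A hypothetical, more change-tolerant variant (illustrative only, not part
of the patch) would assert approximate fractions instead of exact counts:

    test("sampleBy (approximate counts)") {
      val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
      val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
      val counts = sampled.groupBy("key").count().collect()
        .map(r => r.getLong(0) -> r.getLong(1)).toMap
      // Expected counts are roughly 0.1 * 33 and 0.2 * 33; allow generous slack.
      assert(math.abs(counts.getOrElse(0L, 0L) - 3.3) <= 5.0)
      assert(math.abs(counts.getOrElse(1L, 0L) - 6.7) <= 5.0)
      // The unlisted stratum (key = 2) must be dropped entirely.
      assert(!counts.contains(2L))
    }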