From 5cd8ea99f084bee40ee18a0c8e33d0ca0aa6bb60 Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Fri, 1 Sep 2017 13:01:23 +0900
Subject: [PATCH] [SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python

## What changes were proposed in this pull request?

This PR makes `withReplacement` in `DataFrame.sample(...)` optional, defaulting to `False`, consistent with the equivalent Scala / Java API.

In short, the following examples are allowed:

```python
>>> df = spark.range(10)
>>> df.sample(0.5).count()
7
>>> df.sample(fraction=0.5).count()
3
>>> df.sample(0.5, seed=42).count()
5
>>> df.sample(fraction=0.5, seed=42).count()
5
```
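
The first two calls above use no seed, so their counts vary from run to run. A minimal sketch of pinning the seed for reproducibility (same `spark.range(10)` setup as above; with the same seed and fraction, both call styles sample the same rows):

```python
>>> df = spark.range(10)
>>> df.sample(0.5, seed=42).count() == df.sample(fraction=0.5, seed=42).count()
True
```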

In addition, this PR also adds type-checking logic, as below:

```python
>>> df = spark.range(10)
>>> df.sample().count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [].
>>> df.sample(True).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>].
>>> df.sample(42).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'int'>].
>>> df.sample(fraction=False, seed="a").count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>].
>>> df.sample(seed=[1]).count()
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'list'>].
>>> df.sample(withReplacement="a", fraction=0.5, seed=1)
...
TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, <type 'int'>].
```
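
For reference, the dispatch and validation rules can be sketched as a standalone helper; the name `_normalize_sample_args` is hypothetical and only illustrates the rules above (seed type checking omitted for brevity):

```python
# Hypothetical sketch of sample()'s argument normalization; the helper
# name is illustrative and not part of the patch itself.
def _normalize_sample_args(withReplacement=None, fraction=None, seed=None):
    # sample(0.5 [, seed]): a lone positional float is the fraction, and a
    # second positional argument shifts over into the seed slot.
    if isinstance(withReplacement, float):
        if fraction is not None:
            seed = fraction
        fraction = withReplacement
        withReplacement = None
    # Every remaining valid form needs a float fraction, and withReplacement,
    # when given, must be a genuine bool.
    if not isinstance(fraction, float) or (
            withReplacement is not None and type(withReplacement) != bool):
        raise TypeError("withReplacement (optional), fraction (required) and"
                        " seed (optional) should be a bool, float and number")
    return withReplacement, fraction, seed

print(_normalize_sample_args(0.5, 42))       # sample(0.5, 42) -> (None, 0.5, 42)
print(_normalize_sample_args(True, 0.5, 1))  # sample(True, 0.5, 1) -> (True, 0.5, 1)
```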

## How was this patch tested?

Manually tested; unit tests were added (including doctests), and the built Python documentation was checked manually.
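
For instance, the added checks can be exercised from a PySpark shell along these lines (a sketch assuming `spark` is an active `SparkSession`):

```python
from pyspark.sql.utils import IllegalArgumentException

df = spark.range(1)

try:
    df.sample()  # fraction is required
except TypeError as e:
    print(e)  # "... should be a bool, float and number; however, got []."

try:
    df.sample(-1.0)  # negative fractions are rejected on the JVM side
except IllegalArgumentException:
    print("negative fraction rejected")
```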

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18999 from HyukjinKwon/SPARK-21779.
---
 python/pyspark/sql/dataframe.py               | 64 +++++++++++++++++--
 python/pyspark/sql/tests.py                   | 18 ++++++
 .../scala/org/apache/spark/sql/Dataset.scala  |  3 +-
 3 files changed, 77 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d1b2a9c994..c19e599814 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -659,19 +659,69 @@ class DataFrame(object):
         return DataFrame(self._jdf.distinct(), self.sql_ctx)
 
     @since(1.3)
-    def sample(self, withReplacement, fraction, seed=None):
+    def sample(self, withReplacement=None, fraction=None, seed=None):
         """Returns a sampled subset of this :class:`DataFrame`.
 
+        :param withReplacement: Sample with replacement or not (default False).
+        :param fraction: Fraction of rows to generate, range [0.0, 1.0].
+        :param seed: Seed for sampling (default a random seed).
+
         .. note:: This is not guaranteed to provide exactly the fraction specified of the total
             count of the given :class:`DataFrame`.
 
-        >>> df.sample(False, 0.5, 42).count()
-        2
+        .. note:: `fraction` is required; `withReplacement` and `seed` are optional.
+
+        >>> df = spark.range(10)
+        >>> df.sample(0.5, 3).count()
+        4
+        >>> df.sample(fraction=0.5, seed=3).count()
+        4
+        >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count()
+        1
+        >>> df.sample(1.0).count()
+        10
+        >>> df.sample(fraction=1.0).count()
+        10
+        >>> df.sample(False, fraction=1.0).count()
+        10
         """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
-        seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
-        return DataFrame(rdd, self.sql_ctx)
+
+        # For the cases below:
+        #   sample(True, 0.5 [, seed])
+        #   sample(True, fraction=0.5 [, seed])
+        #   sample(withReplacement=False, fraction=0.5 [, seed])
+        is_withReplacement_set = \
+            type(withReplacement) == bool and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(fraction=0.5 [, seed])
+        is_withReplacement_omitted_kwargs = \
+            withReplacement is None and isinstance(fraction, float)
+
+        # For the case below:
+        #   sample(0.5 [, seed])
+        is_withReplacement_omitted_args = isinstance(withReplacement, float)
+
+        if not (is_withReplacement_set
+                or is_withReplacement_omitted_kwargs
+                or is_withReplacement_omitted_args):
+            argtypes = [
+                str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None]
+            raise TypeError(
+                "withReplacement (optional), fraction (required) and seed (optional)"
+                " should be a bool, float and number; however, "
+                "got [%s]." % ", ".join(argtypes))
+
+        if is_withReplacement_omitted_args:
+            if fraction is not None:
+                seed = fraction
+            fraction = withReplacement
+            withReplacement = None
+
+        seed = long(seed) if seed is not None else None
+        args = [arg for arg in [withReplacement, fraction, seed] if arg is not None]
+        jdf = self._jdf.sample(*args)
+        return DataFrame(jdf, self.sql_ctx)
 
     @since(1.5)
     def sampleBy(self, col, fractions, seed=None):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b3102853ce..a2a3ceb29d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase):
         plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
         self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
 
+    def test_sample(self):
+        self.assertRaisesRegexp(
+            TypeError,
+            "should be a bool, float and number",
+            lambda: self.spark.range(1).sample())
+
+        self.assertRaises(
+            TypeError,
+            lambda: self.spark.range(1).sample("a"))
+
+        self.assertRaises(
+            TypeError,
+            lambda: self.spark.range(1).sample(seed="abc"))
+
+        self.assertRaises(
+            IllegalArgumentException,
+            lambda: self.spark.range(1).sample(-1.0))
+
     def test_toDF_with_schema_string(self):
         data = [Row(key=i, value=str(i)) for i in range(100)]
         rdd = self.sc.parallelize(data, 5)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c6707396af..5d8a183b7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1867,7 +1867,8 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement).
+   * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
+   * using a random seed.
    *
    * @param fraction Fraction of rows to generate, range [0.0, 1.0].
    *
-- 
GitLab