From 5cd8ea99f084bee40ee18a0c8e33d0ca0aa6bb60 Mon Sep 17 00:00:00 2001 From: hyukjinkwon <gurwls223@gmail.com> Date: Fri, 1 Sep 2017 13:01:23 +0900 Subject: [PATCH] [SPARK-21779][PYTHON] Simpler DataFrame.sample API in Python ## What changes were proposed in this pull request? This PR makes it possible for `DataFrame.sample(...)` to omit `withReplacement`, defaulting to `False`, consistently with the equivalent Scala / Java API. In short, the following examples are allowed: ```python >>> df = spark.range(10) >>> df.sample(0.5).count() 7 >>> df.sample(fraction=0.5).count() 3 >>> df.sample(0.5, seed=42).count() 5 >>> df.sample(fraction=0.5, seed=42).count() 5 ``` In addition, this PR also adds some type checking logic as below: ```python >>> df = spark.range(10) >>> df.sample().count() ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got []. >>> df.sample(True).count() ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>]. >>> df.sample(42).count() ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'int'>]. >>> df.sample(fraction=False, seed="a").count() ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'bool'>, <type 'str'>]. >>> df.sample(seed=[1]).count() ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'list'>]. >>> df.sample(withReplacement="a", fraction=0.5, seed=1) ... TypeError: withReplacement (optional), fraction (required) and seed (optional) should be a bool, float and number; however, got [<type 'str'>, <type 'float'>, <type 'int'>]. ``` ## How was this patch tested? 
Manually tested, unit tests added in doc tests and manually checked the built documentation for Python. Author: hyukjinkwon <gurwls223@gmail.com> Closes #18999 from HyukjinKwon/SPARK-21779. --- python/pyspark/sql/dataframe.py | 64 +++++++++++++++++-- python/pyspark/sql/tests.py | 18 ++++++ .../scala/org/apache/spark/sql/Dataset.scala | 3 +- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d1b2a9c994..c19e599814 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -659,19 +659,69 @@ class DataFrame(object): return DataFrame(self._jdf.distinct(), self.sql_ctx) @since(1.3) - def sample(self, withReplacement, fraction, seed=None): + def sample(self, withReplacement=None, fraction=None, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + :param withReplacement: Sample with replacement or not (default False). + :param fraction: Fraction of rows to generate, range [0.0, 1.0]. + :param seed: Seed for sampling (default a random seed). + .. note:: This is not guaranteed to provide exactly the fraction specified of the total count of the given :class:`DataFrame`. - >>> df.sample(False, 0.5, 42).count() - 2 + .. note:: `fraction` is required, and `withReplacement` and `seed` are optional. 
+ + >>> df = spark.range(10) + >>> df.sample(0.5, 3).count() + 4 + >>> df.sample(fraction=0.5, seed=3).count() + 4 + >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count() + 1 + >>> df.sample(1.0).count() + 10 + >>> df.sample(fraction=1.0).count() + 10 + >>> df.sample(False, fraction=1.0).count() + 10 """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) + + # For the cases below: + # sample(True, 0.5 [, seed]) + # sample(True, fraction=0.5 [, seed]) + # sample(withReplacement=False, fraction=0.5 [, seed]) + is_withReplacement_set = \ + type(withReplacement) == bool and isinstance(fraction, float) + + # For the case below: + # sample(fraction=0.5 [, seed]) + is_withReplacement_omitted_kwargs = \ + withReplacement is None and isinstance(fraction, float) + + # For the case below: + # sample(0.5 [, seed]) + is_withReplacement_omitted_args = isinstance(withReplacement, float) + + if not (is_withReplacement_set + or is_withReplacement_omitted_kwargs + or is_withReplacement_omitted_args): + argtypes = [ + str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None] + raise TypeError( + "withReplacement (optional), fraction (required) and seed (optional)" + " should be a bool, float and number; however, " + "got [%s]." 
% ", ".join(argtypes)) + + if is_withReplacement_omitted_args: + if fraction is not None: + seed = fraction + fraction = withReplacement + withReplacement = None + + seed = long(seed) if seed is not None else None + args = [arg for arg in [withReplacement, fraction, seed] if arg is not None] + jdf = self._jdf.sample(*args) + return DataFrame(jdf, self.sql_ctx) @since(1.5) def sampleBy(self, col, fractions, seed=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b3102853ce..a2a3ceb29d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase): plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_sample(self): + self.assertRaisesRegexp( + TypeError, + "should be a bool, float and number", + lambda: self.spark.range(1).sample()) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample("a")) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample(seed="abc")) + + self.assertRaises( + IllegalArgumentException, + lambda: self.spark.range(1).sample(-1.0)) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c6707396af..5d8a183b7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1867,7 +1867,8 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement). + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), + * using a random seed. * * @param fraction Fraction of rows to generate, range [0.0, 1.0]. * -- GitLab