diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d1b2a9c9947e181c232c5e6fe8a58d99a89c1cba..c19e599814e540dfa329df7e0c3d5dd19df639a4 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -659,19 +659,69 @@ class DataFrame(object): return DataFrame(self._jdf.distinct(), self.sql_ctx) @since(1.3) - def sample(self, withReplacement, fraction, seed=None): + def sample(self, withReplacement=None, fraction=None, seed=None): """Returns a sampled subset of this :class:`DataFrame`. + :param withReplacement: Sample with replacement or not (default False). + :param fraction: Fraction of rows to generate, range [0.0, 1.0]. + :param seed: Seed for sampling (default a random seed). + .. note:: This is not guaranteed to provide exactly the fraction specified of the total count of the given :class:`DataFrame`. - >>> df.sample(False, 0.5, 42).count() - 2 + .. note:: `fraction` is required, and `withReplacement` and `seed` are optional. + + >>> df = spark.range(10) + >>> df.sample(0.5, 3).count() + 4 + >>> df.sample(fraction=0.5, seed=3).count() + 4 + >>> df.sample(withReplacement=True, fraction=0.5, seed=3).count() + 1 + >>> df.sample(1.0).count() + 10 + >>> df.sample(fraction=1.0).count() + 10 + >>> df.sample(False, fraction=1.0).count() + 10 """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd = self._jdf.sample(withReplacement, fraction, long(seed)) - return DataFrame(rdd, self.sql_ctx) + + # For the cases below: + # sample(True, 0.5 [, seed]) + # sample(True, fraction=0.5 [, seed]) + # sample(withReplacement=False, fraction=0.5 [, seed]) + is_withReplacement_set = \ + type(withReplacement) == bool and isinstance(fraction, float) + + # For the case below: + # sample(fraction=0.5 [, seed]) + is_withReplacement_omitted_kwargs = \ + withReplacement is None and isinstance(fraction, float) + + # For the case below: + # 
sample(0.5 [, seed]) + is_withReplacement_omitted_args = isinstance(withReplacement, float) + + if not (is_withReplacement_set + or is_withReplacement_omitted_kwargs + or is_withReplacement_omitted_args): + argtypes = [ + str(type(arg)) for arg in [withReplacement, fraction, seed] if arg is not None] + raise TypeError( + "withReplacement (optional), fraction (required) and seed (optional)" + " should be a bool, float and number; however, " + "got [%s]." % ", ".join(argtypes)) + + if is_withReplacement_omitted_args: + if fraction is not None: + seed = fraction + fraction = withReplacement + withReplacement = None + + seed = long(seed) if seed is not None else None + args = [arg for arg in [withReplacement, fraction, seed] if arg is not None] + jdf = self._jdf.sample(*args) + return DataFrame(jdf, self.sql_ctx) @since(1.5) def sampleBy(self, col, fractions, seed=None): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index b3102853ceb53266132102a4391a4b0737184be1..a2a3ceb29d49940e12419da975c576b3bd73d3b3 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2108,6 +2108,24 @@ class SQLTests(ReusedPySparkTestCase): plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_sample(self): + self.assertRaisesRegexp( + TypeError, + "should be a bool, float and number", + lambda: self.spark.range(1).sample()) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample("a")) + + self.assertRaises( + TypeError, + lambda: self.spark.range(1).sample(seed="abc")) + + self.assertRaises( + IllegalArgumentException, + lambda: self.spark.range(1).sample(-1.0)) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 
c6707396af1a8a426eb644fc1c0dd8b0ba250ced..5d8a183b7f8753157ac8099e4f26d4cd7376b004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1867,7 +1867,8 @@ class Dataset[T] private[sql]( } /** - * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement). + * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), + * using a random seed. * * @param fraction Fraction of rows to generate, range [0.0, 1.0]. *