From 215281d88ed664547088309cb432da2fed18b8b7 Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Wed, 21 Jun 2017 14:59:52 -0700
Subject: [PATCH] [SPARK-20830][PYSPARK][SQL] Add explode_outer and posexplode_outer

## What changes were proposed in this pull request?

Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`.

## How was this patch tested?

Unit tests, doctests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #18049 from zero323/SPARK-20830.
---
 python/pyspark/sql/functions.py | 65 +++++++++++++++++++++++++++++++++
 python/pyspark/sql/tests.py     | 20 +++++++++-
 2 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 240ae65a61..3416c4b118 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1727,6 +1727,71 @@ def posexplode(col):
     return Column(jc)
 
 
+@since(2.3)
+def explode_outer(col):
+    """Returns a new row for each element in the given array or map.
+    Unlike explode, if the array/map is null or empty then null is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", explode_outer("a_map")).show()
+    +---+----------+----+-----+
+    | id|  an_array| key|value|
+    +---+----------+----+-----+
+    |  1|[foo, bar]|   x|  1.0|
+    |  2|        []|null| null|
+    |  3|      null|null| null|
+    +---+----------+----+-----+
+
+    >>> df.select("id", "a_map", explode_outer("an_array")).show()
+    +---+-------------+----+
+    | id|        a_map| col|
+    +---+-------------+----+
+    |  1|Map(x -> 1.0)| foo|
+    |  1|Map(x -> 1.0)| bar|
+    |  2|        Map()|null|
+    |  3|         null|null|
+    +---+-------------+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.explode_outer(_to_java_column(col))
+    return Column(jc)
+
+
+@since(2.3)
+def posexplode_outer(col):
+    """Returns a new row for each element with position in the given array or map.
+    Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+    +---+----------+----+----+-----+
+    | id|  an_array| pos| key|value|
+    +---+----------+----+----+-----+
+    |  1|[foo, bar]|   0|   x|  1.0|
+    |  2|        []|null|null| null|
+    |  3|      null|null|null| null|
+    +---+----------+----+----+-----+
+    >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+    +---+-------------+----+----+
+    | id|        a_map| pos| col|
+    +---+-------------+----+----+
+    |  1|Map(x -> 1.0)|   0| foo|
+    |  1|Map(x -> 1.0)|   1| bar|
+    |  2|        Map()|null|null|
+    |  3|         null|null|null|
+    +---+-------------+----+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.posexplode_outer(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.6)
 def get_json_object(col, path):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 31f932a363..3b308579a3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -258,8 +258,12 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(isinstance(columns[1], str))
 
     def test_explode(self):
-        from pyspark.sql.functions import explode
-        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
+        d = [
+            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
+            Row(a=1, intlist=[], mapfield={}),
+            Row(a=1, intlist=None, mapfield=None),
+        ]
         rdd = self.sc.parallelize(d)
         data = self.spark.createDataFrame(rdd)
 
@@ -272,6 +276,18 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
+        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
+        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
+
+        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
+        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
+
+        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
+        self.assertEqual(result, [1, 2, 3, None, None])
+
+        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
+        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
+
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
--
GitLab
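
For context, a minimal usage sketch of the two wrappers this patch adds (assumes PySpark >= 2.3 with the patch applied; the local `SparkSession` configuration, app name, and sample data below are illustrative and not part of the patch):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, explode_outer, posexplode_outer

# Illustrative local session; any existing SparkSession works the same way.
spark = SparkSession.builder.master("local[1]").appName("explode-outer-demo").getOrCreate()

df = spark.createDataFrame(
    [(1, ["foo", "bar"]), (2, []), (3, None)],
    ("id", "an_array"),
)

# Plain explode drops rows 2 and 3 (empty and null arrays): only foo/bar survive.
df.select("id", explode("an_array")).show()

# explode_outer keeps rows 2 and 3, emitting null in the exploded column.
df.select("id", explode_outer("an_array")).show()

# posexplode_outer additionally yields the element position
# (pos and col are both null for empty/null input).
df.select("id", posexplode_outer("an_array")).show()

spark.stop()
```

The outer variants mirror the Scala functions of the same names: where `explode`/`posexplode` silently filter out rows whose array or map is empty or null, `explode_outer`/`posexplode_outer` preserve those rows and produce nulls instead, as the doctests and the extended `test_explode` above verify.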