From 215281d88ed664547088309cb432da2fed18b8b7 Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Wed, 21 Jun 2017 14:59:52 -0700
Subject: [PATCH] [SPARK-20830][PYSPARK][SQL] Add posexplode and
 posexplode_outer

## What changes were proposed in this pull request?

Add Python wrappers for `o.a.s.sql.functions.explode_outer` and `o.a.s.sql.functions.posexplode_outer`.

## How was this patch tested?

Unit tests, doctests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #18049 from zero323/SPARK-20830.
---
 python/pyspark/sql/functions.py | 65 +++++++++++++++++++++++++++++++++
 python/pyspark/sql/tests.py     | 20 +++++++++-
 2 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 240ae65a61..3416c4b118 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1727,6 +1727,71 @@ def posexplode(col):
     return Column(jc)
 
 
+@since(2.3)
+def explode_outer(col):
+    """Returns a new row for each element in the given array or map.
+    Unlike explode, if the array/map is null or empty then null is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", explode_outer("a_map")).show()
+    +---+----------+----+-----+
+    | id|  an_array| key|value|
+    +---+----------+----+-----+
+    |  1|[foo, bar]|   x|  1.0|
+    |  2|        []|null| null|
+    |  3|      null|null| null|
+    +---+----------+----+-----+
+
+    >>> df.select("id", "a_map", explode_outer("an_array")).show()
+    +---+-------------+----+
+    | id|        a_map| col|
+    +---+-------------+----+
+    |  1|Map(x -> 1.0)| foo|
+    |  1|Map(x -> 1.0)| bar|
+    |  2|        Map()|null|
+    |  3|         null|null|
+    +---+-------------+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.explode_outer(_to_java_column(col))
+    return Column(jc)
+
+
+@since(2.3)
+def posexplode_outer(col):
+    """Returns a new row for each element with position in the given array or map.
+    Unlike posexplode, if the array/map is null or empty then the row (null, null) is produced.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, ["foo", "bar"], {"x": 1.0}), (2, [], {}), (3, None, None)],
+    ...     ("id", "an_array", "a_map")
+    ... )
+    >>> df.select("id", "an_array", posexplode_outer("a_map")).show()
+    +---+----------+----+----+-----+
+    | id|  an_array| pos| key|value|
+    +---+----------+----+----+-----+
+    |  1|[foo, bar]|   0|   x|  1.0|
+    |  2|        []|null|null| null|
+    |  3|      null|null|null| null|
+    +---+----------+----+----+-----+
+    >>> df.select("id", "a_map", posexplode_outer("an_array")).show()
+    +---+-------------+----+----+
+    | id|        a_map| pos| col|
+    +---+-------------+----+----+
+    |  1|Map(x -> 1.0)|   0| foo|
+    |  1|Map(x -> 1.0)|   1| bar|
+    |  2|        Map()|null|null|
+    |  3|         null|null|null|
+    +---+-------------+----+----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.posexplode_outer(_to_java_column(col))
+    return Column(jc)
+
+
 @ignore_unicode_prefix
 @since(1.6)
 def get_json_object(col, path):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 31f932a363..3b308579a3 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -258,8 +258,12 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(isinstance(columns[1], str))
 
     def test_explode(self):
-        from pyspark.sql.functions import explode
-        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+        from pyspark.sql.functions import explode, explode_outer, posexplode_outer
+        d = [
+            Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"}),
+            Row(a=1, intlist=[], mapfield={}),
+            Row(a=1, intlist=None, mapfield=None),
+        ]
         rdd = self.sc.parallelize(d)
         data = self.spark.createDataFrame(rdd)
 
@@ -272,6 +276,18 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(result[0][0], "a")
         self.assertEqual(result[0][1], "b")
 
+        result = [tuple(x) for x in data.select(posexplode_outer("intlist")).collect()]
+        self.assertEqual(result, [(0, 1), (1, 2), (2, 3), (None, None), (None, None)])
+
+        result = [tuple(x) for x in data.select(posexplode_outer("mapfield")).collect()]
+        self.assertEqual(result, [(0, 'a', 'b'), (None, None, None), (None, None, None)])
+
+        result = [x[0] for x in data.select(explode_outer("intlist")).collect()]
+        self.assertEqual(result, [1, 2, 3, None, None])
+
+        result = [tuple(x) for x in data.select(explode_outer("mapfield")).collect()]
+        self.assertEqual(result, [('a', 'b'), (None, None), (None, None)])
+
     def test_and_in_expression(self):
         self.assertEqual(4, self.df.filter((self.df.key <= 10) & (self.df.value <= "2")).count())
         self.assertRaises(ValueError, lambda: (self.df.key <= 10) and (self.df.value <= "2"))
-- 
GitLab