From 6d0633e3ec9518278fcc7eba58549d4ad3d5813f Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Thu, 14 May 2015 19:49:44 -0700
Subject: [PATCH] [SPARK-7548] [SQL] Add explode function for DataFrames

Add an `explode` function for dataframes and modify the analyzer so that single table generating functions can be present in a select clause along with other expressions.   There are currently the following restrictions:
 - only top level TGFs are allowed (i.e. no `select(explode('list) + 1)`)
 - only one may be present in a single select to avoid potentially confusing implicit Cartesian products.

TODO:
 - [ ] Python

Author: Michael Armbrust <michael@databricks.com>

Closes #6107 from marmbrus/explodeFunction and squashes the following commits:

7ee2c87 [Michael Armbrust] whitespace
6f80ba3 [Michael Armbrust] Update dataframe.py
c176c89 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction
81b5da3 [Michael Armbrust] style
d3faa05 [Michael Armbrust] fix self join case
f9e1e3e [Michael Armbrust] fix python, add since
4f0d0a9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explodeFunction
e710fe4 [Michael Armbrust] add java and python
52ca0dc [Michael Armbrust] [SPARK-7548][SQL] Add explode function for dataframes.
---
 python/pyspark/sql/dataframe.py               |  12 +-
 python/pyspark/sql/functions.py               |  20 +++
 python/pyspark/sql/tests.py                   |  15 +++
 .../sql/catalyst/analysis/Analyzer.scala      | 117 +++++++++++-------
 .../plans/logical/basicOperators.scala        |   3 +
 .../sql/catalyst/analysis/AnalysisSuite.scala |  10 +-
 .../scala/org/apache/spark/sql/Column.scala   |  27 +++-
 .../org/apache/spark/sql/DataFrame.scala      |   5 +-
 .../org/apache/spark/sql/functions.scala      |   5 +
 .../spark/sql/ColumnExpressionSuite.scala     |  60 +++++++++
 10 files changed, 223 insertions(+), 51 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 82cb1c2fdb..2ed95ac8e2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1511,13 +1511,19 @@ class Column(object):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    def alias(self, alias):
-        """Return a alias for this column
+    def alias(self, *alias):
+        """Returns this column aliased with a new name or names (in the case of expressions that
+        return more than one column, such as explode).
 
         >>> df.select(df.age.alias("age2")).collect()
         [Row(age2=2), Row(age2=5)]
         """
-        return Column(getattr(self._jc, "as")(alias))
+
+        if len(alias) == 1:
+            return Column(getattr(self._jc, "as")(alias[0]))
+        else:
+            sc = SparkContext._active_spark_context
+            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
 
     @ignore_unicode_prefix
     def cast(self, dataType):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d91265ee0b..6cd6974b0e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None):
     return Column(jc)
 
 
+def explode(col):
+    """Returns a new row for each element in the given array or map.
+
+    >>> from pyspark.sql import Row
+    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+    >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+    [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+    >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+    +---+-----+
+    |key|value|
+    +---+-----+
+    |  a|    b|
+    +---+-----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.explode(_to_java_column(col))
+    return Column(jc)
+
+
 def coalesce(*cols):
     """Returns the first column that is not null.
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1922d03af6..d37c5dbed7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_explode(self):
+        from pyspark.sql.functions import explode
+        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
+        self.assertEqual(result[0][0], 1)
+        self.assertEqual(result[1][0], 2)
+        self.assertEqual(result[2][0], 3)
+
+        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
+        self.assertEqual(result[0][0], "a")
+        self.assertEqual(result[0][1], "b")
+
     def test_udf_with_callable(self):
         d = [Row(number=i, squared=i**2) for i in range(10)]
         rdd = self.sc.parallelize(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4baeeb5b58..0b6e1d44b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,6 @@ class Analyzer(
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
       ResolveGenerate ::
-      ImplicitGenerate ::
       ResolveFunctions ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
@@ -323,6 +322,11 @@ class Analyzer(
               if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
             (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
+          case oldVersion: Generate
+              if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+            val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+            (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
           case oldVersion @ Window(_, windowExpressions, _, child)
               if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
                 .nonEmpty =>
@@ -521,66 +525,89 @@ class Analyzer(
   }
 
   /**
-   * When a SELECT clause has only a single expression and that expression is a
-   * [[catalyst.expressions.Generator Generator]] we convert the
-   * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]].
+   * Rewrites table generating expressions that either need one or more of the following in order
+   * to be resolved:
+   *  - concrete attribute references for their output.
+   *  - to be relocated from a SELECT clause (i.e. from  a [[Project]]) into a [[Generate]]).
+   *
+   * Names for the output [[Attributes]] are extracted from [[Alias]] or [[MultiAlias]] expressions
+   * that wrap the [[Generator]]. If more than one [[Generator]] is found in a Project, an
+   * [[AnalysisException]] is throw.
    */
-  object ImplicitGenerate extends Rule[LogicalPlan] {
+  object ResolveGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case Project(Seq(Alias(g: Generator, name)), child) =>
-        Generate(g, join = false, outer = false,
-          qualifier = None, UnresolvedAttribute(name) :: Nil, child)
-      case Project(Seq(MultiAlias(g: Generator, names)), child) =>
-        Generate(g, join = false, outer = false,
-          qualifier = None, names.map(UnresolvedAttribute(_)), child)
+      case p: Generate if !p.child.resolved || !p.generator.resolved => p
+      case g: Generate if g.resolved == false =>
+          g.copy(
+            generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
+
+      case p @ Project(projectList, child) =>
+        // Holds the resolved generator, if one exists in the project list.
+        var resolvedGenerator: Generate = null
+
+        val newProjectList = projectList.flatMap {
+          case AliasedGenerator(generator, names) if generator.childrenResolved =>
+            if (resolvedGenerator != null) {
+              failAnalysis(
+                s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " +
+                s"and ${generator.nodeName} found.")
+            }
+
+            resolvedGenerator =
+              Generate(
+                generator,
+                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+                outer = false,
+                qualifier = None,
+                generatorOutput = makeGeneratorOutput(generator, names),
+                child)
+
+            resolvedGenerator.generatorOutput
+          case other => other :: Nil
+        }
+
+        if (resolvedGenerator != null) {
+          Project(newProjectList, resolvedGenerator)
+        } else {
+          p
+        }
     }
-  }
 
-  /**
-   * Resolve the Generate, if the output names specified, we will take them, otherwise
-   * we will try to provide the default names, which follow the same rule with Hive.
-   */
-  object ResolveGenerate extends Rule[LogicalPlan] {
-    // Construct the output attributes for the generator,
-    // The output attribute names can be either specified or
-    // auto generated.
+    /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
+    private object AliasedGenerator {
+      def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
+        case Alias(g: Generator, name) => Some((g, name :: Nil))
+        case MultiAlias(g: Generator, names) => Some(g, names)
+        case _ => None
+      }
+    }
+
+    /**
+     * Construct the output attributes for a [[Generator]], given a list of names.  If the list of
+     * names is empty names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's defaults.
+     */
     private def makeGeneratorOutput(
         generator: Generator,
-        generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+        names: Seq[String]): Seq[Attribute] = {
       val elementTypes = generator.elementTypes
 
-      if (generatorOutput.length == elementTypes.length) {
-        generatorOutput.zip(elementTypes).map {
-          case (a, (t, nullable)) if !a.resolved =>
-            AttributeReference(a.name, t, nullable)()
-          case (a, _) => a
+      if (names.length == elementTypes.length) {
+        names.zip(elementTypes).map {
+          case (name, (t, nullable)) =>
+            AttributeReference(name, t, nullable)()
         }
-      } else if (generatorOutput.length == 0) {
+      } else if (names.isEmpty) {
         elementTypes.zipWithIndex.map {
           // keep the default column names as Hive does _c0, _c1, _cN
           case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
         }
       } else {
-        throw new AnalysisException(
-          s"""
-             |The number of aliases supplied in the AS clause does not match
-             |the number of columns output by the UDTF expected
-             |${elementTypes.size} aliases but got ${generatorOutput.size}
-           """.stripMargin)
+        failAnalysis(
+          "The number of aliases supplied in the AS clause does not match the number of columns " +
+          s"output by the UDTF expected ${elementTypes.size} aliases but got " +
+          s"${names.mkString(",")} ")
       }
     }
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case p: Generate if !p.child.resolved || !p.generator.resolved => p
-      case p: Generate if p.resolved == false =>
-        // if the generator output names are not specified, we will use the default ones.
-        Generate(
-          p.generator,
-          join = p.join,
-          outer = p.outer,
-          p.qualifier,
-          makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
-    }
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0f349f9d11..01f4b6e9bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -59,6 +59,9 @@ case class Generate(
     child: LogicalPlan)
   extends UnaryNode {
 
+  /** The set of all attributes produced by this node. */
+  def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+
   override lazy val resolved: Boolean = {
     generator.resolved &&
       childrenResolved &&
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6f2f35564d..e1d6ac462f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
       StructField("cField", StringType) :: Nil
     ))())
 
+  val listRelation = LocalRelation(
+    AttributeReference("list", ArrayType(IntegerType))())
+
   before {
     caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
     caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -159,10 +162,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
         }
       }
 
-      errorMessages.foreach(m => assert(error.getMessage contains m))
+      errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
     }
   }
 
+  errorTest(
+    "too many generators",
+    listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+    "only one generator" :: "explode" :: Nil)
+
   errorTest(
     "unresolved attributes",
     testRelation.select('abcd),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bf1320ccb..dc0aeea7c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.sql
 
 import scala.language.implicitConversions
+import scala.collection.JavaConversions._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
 import org.apache.spark.sql.types._
 
 
@@ -727,6 +728,30 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def as(alias: String): Column = Alias(expr, alias)()
 
+  /**
+   * (Scala-specific) Assigns the given aliases to the results of a table generating function.
+   * {{{
+   *   // Renames colA to colB in select output.
+   *   df.select(explode($"myMap").as("key" :: "value" :: Nil))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 1.4.0
+   */
+  def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
+
+  /**
+   * Assigns the given aliases to the results of a table generating function.
+   * {{{
+   *   // Renames colA to colB in select output.
+   *   df.select(explode($"myMap").as("key" :: "value" :: Nil))
+   * }}}
+   *
+   * @group expr_ops
+   * @since 1.4.0
+   */
+  def as(aliases: Array[String]): Column = MultiAlias(expr, aliases)
+
   /**
    * Gives the column an alias.
    * {{{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4fd5105c27..2e20c3d3f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -593,6 +593,9 @@ class DataFrame private[sql](
   def select(cols: Column*): DataFrame = {
     val namedExpressions = cols.map {
       case Column(expr: NamedExpression) => expr
+      // Leave an unaliased explode with an empty list of names since the analzyer will generate the
+      // correct defaults after the nested expression's type has been resolved.
+      case Column(explode: Explode) => MultiAlias(explode, Nil)
       case Column(expr: Expression) => Alias(expr, expr.prettyString)()
     }
     // When user continuously call `select`, speed up analysis by collapsing `Project`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4404ad8ad6..6640631cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -363,6 +363,11 @@ object functions {
   @scala.annotation.varargs
   def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
 
+  /**
+   * Creates a new row for each element in the given array or map column.
+   */
+   def explode(e: Column): Column = Explode(e.expr)
+
   /**
    * Converts a string exprsesion to lower case.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 269e185543..9bdf201b3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,6 +27,66 @@ import org.apache.spark.sql.types._
 class ColumnExpressionSuite extends QueryTest {
   import org.apache.spark.sql.TestData._
 
+  test("single explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+    checkAnswer(
+      df.select(explode('intList)),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+  }
+
+  test("explode and other columns") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select($"a", explode('intList)),
+      Row(1, 1) ::
+      Row(1, 2) ::
+      Row(1, 3) :: Nil)
+
+    checkAnswer(
+      df.select($"*", explode('intList)),
+      Row(1, Seq(1,2,3), 1) ::
+      Row(1, Seq(1,2,3), 2) ::
+      Row(1, Seq(1,2,3), 3) :: Nil)
+  }
+
+  test("aliased explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select('int),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select(sum('int)),
+      Row(6) :: Nil)
+  }
+
+  test("explode on map") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map)),
+      Row("a", "b"))
+  }
+
+  test("explode on map with aliases") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+      Row("a", "b"))
+  }
+
+  test("self join explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+    val exploded = df.select(explode('intList).as('i))
+
+    checkAnswer(
+      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+      Row(3) :: Nil)
+  }
+
   test("collect on column produced by a binary operator") {
     val df = Seq((1, 2, 3)).toDF("a", "b", "c")
     checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
-- 
GitLab