diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 9f34414f64d1b1d0157640c82dfebc00c024bdab..66a3490a640ba5288d0bac5482736a362461c875 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -384,6 +384,26 @@ class SQLTests(ReusedPySparkTestCase):
         row = df.select(explode(f(*df))).groupBy().sum().first()
         self.assertEqual(row[0], 10)
 
+        df = self.spark.range(3)
+        res = df.select("id", explode(f(df.id))).collect()
+        self.assertEqual(res[0][0], 1)
+        self.assertEqual(res[0][1], 0)
+        self.assertEqual(res[1][0], 2)
+        self.assertEqual(res[1][1], 0)
+        self.assertEqual(res[2][0], 2)
+        self.assertEqual(res[2][1], 1)
+
+        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
+        res = df.select("id", explode(range_udf(df.id))).collect()
+        self.assertEqual(res[0][0], 0)
+        self.assertEqual(res[0][1], -1)
+        self.assertEqual(res[1][0], 0)
+        self.assertEqual(res[1][1], 0)
+        self.assertEqual(res[2][0], 1)
+        self.assertEqual(res[2][1], 0)
+        self.assertEqual(res[3][0], 1)
+        self.assertEqual(res[3][1], 1)
+
     def test_udf_with_order_by_and_limit(self):
         from pyspark.sql.functions import udf
         my_copy = udf(lambda x: x, IntegerType())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index da42df33663077f9c0e9b64730341ad348634bdb..304367de4cf6a1221e117e79e2fdb7660d67300c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -94,13 +94,13 @@ case class Generate(
 
   override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
 
-  def output: Seq[Attribute] = {
-    val qualified = qualifier.map(q =>
-      // prepend the new qualifier to the existed one
-      generatorOutput.map(a => a.withQualifier(Some(q)))
-    ).getOrElse(generatorOutput)
+  val qualifiedGeneratorOutput: Seq[Attribute] = qualifier.map { q =>
+    // prepend the new qualifier to the existing one
+    generatorOutput.map(a => a.withQualifier(Some(q)))
+  }.getOrElse(generatorOutput)
 
-    if (join) child.output ++ qualified else qualified
+  def output: Seq[Attribute] = {
+    if (join) child.output ++ qualifiedGeneratorOutput else qualifiedGeneratorOutput
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 19fbf0c162048e0086a5ffd9a1e45fb6e9cfddec..1d9f96bcb5344c19695b3538e016e251d3bd4cc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -45,17 +45,26 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In
  *              it.
  * @param outer when true, each input row will be output at least once, even if the output of the
  *              given `generator` is empty. `outer` has no effect when `join` is false.
- * @param output the output attributes of this node, which constructed in analysis phase,
- *               and we can not change it, as the parent node bound with it already.
+ * @param generatorOutput the qualified output attributes of the generator of this node, which
+ *                        are constructed during the analysis phase and cannot be changed, as
+ *                        the parent node is already bound to them.
  */
 case class GenerateExec(
     generator: Generator,
     join: Boolean,
     outer: Boolean,
-    output: Seq[Attribute],
+    generatorOutput: Seq[Attribute],
     child: SparkPlan)
   extends UnaryExecNode {
 
+  override def output: Seq[Attribute] = {
+    if (join) {
+      child.output ++ generatorOutput
+    } else {
+      generatorOutput
+    }
+  }
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2308ae8a6c61188eab7225d361e7661f08bb8947..d88cbdfbcfa0e87f16dfbb88ab40045d8a7fbf89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -403,7 +403,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Union(unionChildren) =>
         execution.UnionExec(unionChildren.map(planLater)) :: Nil
       case g @ logical.Generate(generator, join, outer, _, _, child) =>
         execution.GenerateExec(
-          generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
+          generator, join = join, outer = outer, g.qualifiedGeneratorOutput,
+          planLater(child)) :: Nil
       case logical.OneRowRelation =>
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>
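
For reviewers, the user-visible effect of the change: `GenerateExec` now derives its `output` from `child.output ++ generatorOutput` instead of receiving a fixed attribute list frozen at planning time, and the planner passes `g.qualifiedGeneratorOutput` so qualifiers survive into the physical plan. Below is a minimal standalone sketch of the behavior the new Python test pins down; the local `SparkSession` setup (master, app name) is illustrative and not part of this patch.

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, udf
from pyspark.sql.types import ArrayType, IntegerType

# Illustrative session setup; the patch's test reuses the suite's session.
spark = SparkSession.builder.master("local[2]").appName("generate-udf-demo").getOrCreate()

df = spark.range(3)  # ids 0, 1, 2

# An array-producing Python UDF; selecting its exploded result next to
# another column routes the query through Generate with a rewritten child.
range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))

# With this patch the projection below resolves correctly, since the
# physical Generate recomputes its output against the rewritten child.
df.select("id", explode(range_udf(df.id))).show()
# Expected rows (id, col): (0, -1), (0, 0), (1, 0), (1, 1), (2, 1), (2, 2)

spark.stop()
```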