From a848d552ef6b5d0d3bb3b2da903478437a8b10aa Mon Sep 17 00:00:00 2001 From: hyukjinkwon <gurwls223@gmail.com> Date: Tue, 4 Jul 2017 11:35:08 +0900 Subject: [PATCH] [SPARK-21264][PYTHON] Call cross join path in join without 'on' and with 'how' ## What changes were proposed in this pull request? Currently, it throws a NPE when missing columns but join type is speicified in join at PySpark as below: ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... py4j.protocol.Py4JJavaError: An error occurred while calling o66.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` ... py4j.protocol.Py4JJavaError: An error occurred while calling o84.join. : java.lang.NullPointerException at org.apache.spark.sql.Dataset.join(Dataset.scala:931) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) ... ``` This PR suggests to follow Scala's one as below: ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "false") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` org.apache.spark.sql.AnalysisException: Detected cartesian product for INNER join between logical plans Range (0, 1, step=1, splits=Some(8)) and Range (0, 1, step=1, splits=Some(8)) Join condition is missing or trivial. Use the CROSS JOIN syntax to allow cartesian products between these relations.; ... ``` ```scala scala> spark.conf.set("spark.sql.crossJoin.enabled", "true") scala> spark.range(1).join(spark.range(1), Seq.empty[String], "inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` **After** ```python spark.conf.set("spark.sql.crossJoin.enabled", "false") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` Traceback (most recent call last): ... pyspark.sql.utils.AnalysisException: u'Detected cartesian product for INNER join between logical plans\nRange (0, 1, step=1, splits=Some(8))\nand\nRange (0, 1, step=1, splits=Some(8))\nJoin condition is missing or trivial.\nUse the CROSS JOIN syntax to allow cartesian products between these relations.;' ``` ```python spark.conf.set("spark.sql.crossJoin.enabled", "true") spark.range(1).join(spark.range(1), how="inner").show() ``` ``` +---+---+ | id| id| +---+---+ | 0| 0| +---+---+ ``` ## How was this patch tested? Added tests in `python/pyspark/sql/tests.py`. Author: hyukjinkwon <gurwls223@gmail.com> Closes #18484 from HyukjinKwon/SPARK-21264. --- python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/tests.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 0649271ed2..27a6dad891 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -833,6 +833,8 @@ class DataFrame(object): else: if how is None: how = "inner" + if on is None: + on = self._jseq([]) assert isinstance(how, basestring), "how should be basestring" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0a1cd6856b..c105969b26 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2021,6 +2021,22 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(df.schema.simpleString(), "struct<value:int>") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + def test_join_without_on(self): + df1 = self.spark.range(1).toDF("a") + df2 = self.spark.range(1).toDF("b") + + try: + self.spark.conf.set("spark.sql.crossJoin.enabled", "false") + self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect()) + + self.spark.conf.set("spark.sql.crossJoin.enabled", "true") + actual = df1.join(df2, how="inner").collect() + expected = [Row(a=0, b=0)] + self.assertEqual(actual, expected) + finally: + # We should unset this. Otherwise, other tests are affected. + self.spark.conf.unset("spark.sql.crossJoin.enabled") + # Regression test for invalid join methods when on is None, Spark-14761 def test_invalid_join_method(self): df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) -- GitLab