From 3320b0ba262159c0c7209ce39b353c93c597077d Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@appier.com>
Date: Fri, 31 Jul 2015 22:26:30 -0700
Subject: [PATCH] [SPARK-9415][SQL] Throw AnalysisException when using MapType
 on Join and Aggregate

JIRA: https://issues.apache.org/jira/browse/SPARK-9415

Following up #7787. We shouldn't use MapType as grouping keys and join keys too.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #7819 from viirya/map_join_groupby and squashes the following commits:

005ee0c [Liang-Chi Hsieh] For comments.
7463398 [Liang-Chi Hsieh] MapType can't be used as join keys, grouping keys.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala | 16 +++--
 .../analysis/AnalysisErrorSuite.scala         | 68 ++++++++++++++++++-
 .../spark/sql/DataFrameAggregateSuite.scala   | 10 ---
 .../org/apache/spark/sql/JoinSuite.scala      |  8 ---
 4 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 0ebc3d180a..364569d8f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -92,8 +92,11 @@ trait CheckAnalysis {
               case p: Predicate =>
                 p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
               case e if e.dataType.isInstanceOf[BinaryType] =>
-                failAnalysis(s"expression ${e.prettyString} in join condition " +
-                  s"'${condition.prettyString}' can't be binary type.")
+                failAnalysis(s"binary type expression ${e.prettyString} cannot be used " +
+                  "in join conditions")
+              case e if e.dataType.isInstanceOf[MapType] =>
+                failAnalysis(s"map type expression ${e.prettyString} cannot be used " +
+                  "in join conditions")
               case _ => // OK
             }
 
@@ -114,13 +117,16 @@ trait CheckAnalysis {
 
             def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
               case BinaryType =>
-                failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " +
-                  s"not be binary type.")
+                failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
+                  "in grouping expression")
+              case m: MapType =>
+                failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
+                  "in grouping expression")
               case _ => // OK
             }
 
             aggregateExprs.foreach(checkValidAggregateExpression)
-            aggregateExprs.foreach(checkValidGroupingExprs)
+            groupingExprs.foreach(checkValidGroupingExprs)
 
           case Sort(orders, _, _) =>
             orders.foreach { order =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 2588df9824..aa19cdce31 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -181,7 +181,71 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
     val error = intercept[AnalysisException] {
       SimpleAnalyzer.checkAnalysis(join)
     }
-    error.message.contains("Failure when resolving conflicting references in Join")
-    error.message.contains("Conflicting attributes")
+    assert(error.message.contains("Failure when resolving conflicting references in Join"))
+    assert(error.message.contains("Conflicting attributes"))
+  }
+
+  test("aggregation can't work on binary and map types") {
+    val plan =
+      Aggregate(
+        AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
+        Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        LocalRelation(
+          AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+          AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+    val error = intercept[AnalysisException] {
+      caseSensitiveAnalyze(plan)
+    }
+    assert(error.message.contains("binary type expression a cannot be used in grouping expression"))
+
+    val plan2 =
+      Aggregate(
+        AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
+        Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        LocalRelation(
+          AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+          AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+    val error2 = intercept[AnalysisException] {
+      caseSensitiveAnalyze(plan2)
+    }
+    assert(error2.message.contains("map type expression a cannot be used in grouping expression"))
+  }
+
+  test("Join can't work on binary and map types") {
+    val plan =
+      Join(
+        LocalRelation(
+          AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+          AttributeReference("b", IntegerType)(exprId = ExprId(1))),
+        LocalRelation(
+          AttributeReference("c", BinaryType)(exprId = ExprId(4)),
+          AttributeReference("d", IntegerType)(exprId = ExprId(3))),
+        Inner,
+        Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
+          AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
+
+    val error = intercept[AnalysisException] {
+      caseSensitiveAnalyze(plan)
+    }
+    assert(error.message.contains("binary type expression a cannot be used in join conditions"))
+
+    val plan2 =
+      Join(
+        LocalRelation(
+          AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+          AttributeReference("b", IntegerType)(exprId = ExprId(1))),
+        LocalRelation(
+          AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
+          AttributeReference("d", IntegerType)(exprId = ExprId(3))),
+        Inner,
+        Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
+          AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
+
+    val error2 = intercept[AnalysisException] {
+      caseSensitiveAnalyze(plan2)
+    }
+    assert(error2.message.contains("map type expression a cannot be used in join conditions"))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 228ece8065..f9cff7440a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -190,14 +190,4 @@ class DataFrameAggregateSuite extends QueryTest {
       emptyTableData.agg(sumDistinct('a)),
       Row(null))
   }
-
-  test("aggregation can't work on binary type") {
-    val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
-    intercept[AnalysisException] {
-      df.groupBy("c").agg(count("*"))
-    }
-    intercept[AnalysisException] {
-      df.distinct
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 27c08f6464..5bef1d8966 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -490,12 +490,4 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
         Row(3, 2) :: Nil)
 
   }
-
-  test("Join can't work on binary type") {
-    val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
-    val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType)
-    intercept[AnalysisException] {
-      left.join(right, ($"left.N" === $"right.N"), "full")
-    }
-  }
 }
-- 
GitLab