diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 96f2e38946f1cad412f913fb6f127cbbba0ec050..d1d2c59caed9a128568a911e4122a89b07b605ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1836,13 +1836,25 @@ class Analyzer( } private def commonNaturalJoinProcessing( - left: LogicalPlan, - right: LogicalPlan, - joinType: JoinType, - joinNames: Seq[String], - condition: Option[Expression]) = { - val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) - val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) + left: LogicalPlan, + right: LogicalPlan, + joinType: JoinType, + joinNames: Seq[String], + condition: Option[Expression]) = { + val leftKeys = joinNames.map { keyName => + val joinColumn = left.output.find(attr => resolver(attr.name, keyName)) + assert( + joinColumn.isDefined, + s"$keyName should exist in ${left.output.map(_.name).mkString(",")}") + joinColumn.get + } + val rightKeys = joinNames.map { keyName => + val joinColumn = right.output.find(attr => resolver(attr.name, keyName)) + assert( + joinColumn.isDefined, + s"$keyName should exist in ${right.output.map(_.name).mkString(",")}") + joinColumn.get + } val joinPairs = leftKeys.zip(rightKeys) val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index 748579df4158074e68e84a4b76e90ba2845709dd..100ec4d53fb818c1f46dcdd46162dcf3b3279d34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -113,4 +113,34 @@ class ResolveNaturalJoinSuite extends AnalysisTest { assert(error.message.contains( "using columns ['d] can not be resolved given input columns: [b, a, c]")) } + + test("using join with a case sensitive analyzer") { + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = true) + } + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None) + assertAnalysisError( + usingPlan, + Seq("using columns ['A] can not be resolved given input columns: [b, a, c, a]")) + } + } + + test("using join with a case insensitive analyzer") { + val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c) + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = false) + } + + { + val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None) + checkAnalysis(usingPlan, expected, caseSensitive = false) + } + } }