diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index f21a39a2d473073d2b928e054ae0415725c5c7b6..2296946cd7c5e4a1764000d302689591a886d5b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -125,12 +125,13 @@ object ExpressionEncoder { } } else { val input = BoundReference(index, enc.schema, nullable = true) - enc.deserializer.transformUp { + val deserialized = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } + If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2f2323fa3a25f52b4e992c5c76dd620ab9ef25fa..c2e3ab82ff16cfc4fe21c627b17efdb6f243617b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier -import scala.annotation.tailrec import scala.language.existentials import scala.reflect.ClassTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3a6ec4595e78e90b9832e91a939bc86694d4be90..369b772d322c01f7b1af4142228f9a5fa29bb1fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -747,31 +747,62 @@ class Dataset[T] private[sql]( */ @Experimental def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan - - val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType = - JoinType(joinType), Some(condition.expr))) - val leftOutput = joined.analyzed.output.take(left.output.length) - val rightOutput = joined.analyzed.output.takeRight(right.output.length) + // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, + // etc. + val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr))).analyzed.asInstanceOf[Join] + + // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", + // to match the schema for the encoder of the join result. + // Note that we do this before joining them, to enable the join operator to return null for one + // side, in cases like outer-join. + val left = { + val combined = if (this.unresolvedTEncoder.flat) { + assert(joined.left.output.length == 1) + Alias(joined.left.output.head, "_1")() + } else { + Alias(CreateStruct(joined.left.output), "_1")() + } + Project(combined :: Nil, joined.left) + } - val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(leftOutput.head, "_1")() - case _ => Alias(CreateStruct(leftOutput), "_1")() + val right = { + val combined = if (other.unresolvedTEncoder.flat) { + assert(joined.right.output.length == 1) + Alias(joined.right.output.head, "_2")() + } else { + Alias(CreateStruct(joined.right.output), "_2")() + } + Project(combined :: Nil, joined.right) } - val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(rightOutput.head, "_2")() - case _ => Alias(CreateStruct(rightOutput), "_2")() + + // Rewrites the join condition to make the attribute point to correct column/field, after we + // combine the outputs of each join side. + val conditionExpr = joined.condition.get transformUp { + case a: Attribute if joined.left.outputSet.contains(a) => + if (this.unresolvedTEncoder.flat) { + left.output.head + } else { + val index = joined.left.output.indexWhere(_.exprId == a.exprId) + GetStructField(left.output.head, index) + } + case a: Attribute if joined.right.outputSet.contains(a) => + if (other.unresolvedTEncoder.flat) { + right.output.head + } else { + val index = joined.right.output.indexWhere(_.exprId == a.exprId) + GetStructField(right.output.head, index) + } } implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withTypedPlan { - Project( - leftData :: rightData :: Nil, - joined.analyzed) - } + withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8fc4dc9f174660df89ed890c5a99e0e5a83ad652..0b6874e3b8ad32a4697c346055689be09a0661f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1), (2, 2)) } - test("joinWith, expression condition, outer join") { - val nullInteger = null.asInstanceOf[Integer] - val nullString = null.asInstanceOf[String] - val ds1 = Seq(ClassNullableData("a", 1), - ClassNullableData("c", 3)).toDS() - val ds2 = Seq(("a", new Integer(1)), - ("b", new Integer(2))).toDS() - - checkDataset( - ds1.joinWith(ds2, $"_1" === $"a", "outer"), - (ClassNullableData("a", 1), ("a", new Integer(1))), - (ClassNullableData("c", 3), (nullString, nullInteger)), - (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) - } - test("joinWith tuple with primitive, expression") { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() @@ -783,6 +768,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ds.filter(_.b > 1).collect().toSeq } } + + test("SPARK-15441: Dataset outer join") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") + val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") + val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + val result = joined.collect().toSet + assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) + } } case class Generic[T](id: T, value: Double)