From 8640cdb836b4964e4af891d9959af64a2e1f304e Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Wed, 1 Jun 2016 16:16:54 -0700 Subject: [PATCH] [SPARK-15441][SQL] support null object in Dataset outer-join ## What changes were proposed in this pull request? Currently we can't encode top level null object into internal row, as Spark SQL doesn't allow row to be null, only its columns can be null. This was not a problem before, as we assumed the input object is never null. However, for outer join, we do need the semantics of null object. This PR fixes this problem by making both join sides produce a single column, i.e. nest the logical plan output (by `CreateStruct`), so that we have an extra level to represent top level null object. ## How was this patch tested? new test in `DatasetSuite` Author: Wenchen Fan <wenchen@databricks.com> Closes #13425 from cloud-fan/outer-join2. --- .../catalyst/encoders/ExpressionEncoder.scala | 3 +- .../expressions/objects/objects.scala | 1 - .../scala/org/apache/spark/sql/Dataset.scala | 67 ++++++++++++++----- .../org/apache/spark/sql/DatasetSuite.scala | 23 +++---- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index f21a39a2d4..2296946cd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -125,12 +125,13 @@ object ExpressionEncoder { } } else { val input = BoundReference(index, enc.schema, nullable = true) - enc.deserializer.transformUp { + val deserialized = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) case BoundReference(ordinal, dt, _) => 
GetStructField(input, ordinal) } + If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 2f2323fa3a..c2e3ab82ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions.objects import java.lang.reflect.Modifier -import scala.annotation.tailrec import scala.language.existentials import scala.reflect.ClassTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3a6ec4595e..369b772d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -747,31 +747,62 @@ class Dataset[T] private[sql]( */ @Experimental def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - val left = this.logicalPlan - val right = other.logicalPlan - - val joined = sparkSession.sessionState.executePlan(Join(left, right, joinType = - JoinType(joinType), Some(condition.expr))) - val leftOutput = joined.analyzed.output.take(left.output.length) - val rightOutput = joined.analyzed.output.takeRight(right.output.length) + // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, + // etc. + val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr))).analyzed.asInstanceOf[Join] + + // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", + // to match the schema for the encoder of the join result. 
+ // Note that we do this before joining them, to enable the join operator to return null for one + // side, in cases like outer-join. + val left = { + val combined = if (this.unresolvedTEncoder.flat) { + assert(joined.left.output.length == 1) + Alias(joined.left.output.head, "_1")() + } else { + Alias(CreateStruct(joined.left.output), "_1")() + } + Project(combined :: Nil, joined.left) + } - val leftData = this.unresolvedTEncoder match { - case e if e.flat => Alias(leftOutput.head, "_1")() - case _ => Alias(CreateStruct(leftOutput), "_1")() + val right = { + val combined = if (other.unresolvedTEncoder.flat) { + assert(joined.right.output.length == 1) + Alias(joined.right.output.head, "_2")() + } else { + Alias(CreateStruct(joined.right.output), "_2")() + } + Project(combined :: Nil, joined.right) } - val rightData = other.unresolvedTEncoder match { - case e if e.flat => Alias(rightOutput.head, "_2")() - case _ => Alias(CreateStruct(rightOutput), "_2")() + + // Rewrites the join condition to make the attribute point to correct column/field, after we + // combine the outputs of each join side. 
+ val conditionExpr = joined.condition.get transformUp { + case a: Attribute if joined.left.outputSet.contains(a) => + if (this.unresolvedTEncoder.flat) { + left.output.head + } else { + val index = joined.left.output.indexWhere(_.exprId == a.exprId) + GetStructField(left.output.head, index) + } + case a: Attribute if joined.right.outputSet.contains(a) => + if (other.unresolvedTEncoder.flat) { + right.output.head + } else { + val index = joined.right.output.indexWhere(_.exprId == a.exprId) + GetStructField(right.output.head, index) + } } implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withTypedPlan { - Project( - leftData :: rightData :: Nil, - joined.analyzed) - } + withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr))) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8fc4dc9f17..0b6874e3b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -253,21 +253,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1), (2, 2)) } - test("joinWith, expression condition, outer join") { - val nullInteger = null.asInstanceOf[Integer] - val nullString = null.asInstanceOf[String] - val ds1 = Seq(ClassNullableData("a", 1), - ClassNullableData("c", 3)).toDS() - val ds2 = Seq(("a", new Integer(1)), - ("b", new Integer(2))).toDS() - - checkDataset( - ds1.joinWith(ds2, $"_1" === $"a", "outer"), - (ClassNullableData("a", 1), ("a", new Integer(1))), - (ClassNullableData("c", 3), (nullString, nullInteger)), - (ClassNullableData(nullString, nullInteger), ("b", new Integer(2)))) - } - test("joinWith tuple with primitive, expression") { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() @@ -783,6 +768,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext 
{ ds.filter(_.b > 1).collect().toSeq } } + + test("SPARK-15441: Dataset outer join") { + val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") + val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") + val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + val result = joined.collect().toSet + assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) + } } case class Generic[T](id: T, value: Double) -- GitLab