Skip to content
Snippets Groups Projects
Commit 425ff03f authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-11436] [SQL] rebind right encoder when join 2 datasets

When we join 2 datasets, we will combine 2 encoders into a tupled one, and use it as the encoder for the jioned dataset. Assume both of the 2 encoders are flat, their `constructExpression`s both reference to the first element of input row. However, when we combine 2 encoders, the schema of input row changed,  now the right encoder should reference to second element of input row. So we should rebind right encoder to let it know the new schema of input row before combine it.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9391 from cloud-fan/join and squashes the following commits:

846d3ab [Wenchen Fan] rebind right encoder when join 2 datasets
parent 67e23b39
No related branches found
No related tags found
No related merge requests found
...@@ -390,7 +390,9 @@ class Dataset[T] private( ...@@ -390,7 +390,9 @@ class Dataset[T] private(
val rightEncoder = val rightEncoder =
if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
implicit val tuple2Encoder: Encoder[(T, U)] = implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(leftEncoder, rightEncoder) ExpressionEncoder.tuple(
leftEncoder,
rightEncoder.rebind(right.output, left.output ++ right.output))
withPlan[(T, U)](other) { (left, right) => withPlan[(T, U)](other) { (left, right) =>
Project( Project(
......
...@@ -214,4 +214,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ...@@ -214,4 +214,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
cogrouped, cogrouped,
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er") 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
} }
test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")
val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
checkAnswer(joined, ("2", 2))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment