Commit fb0894b3 authored by Tejas Patil, committed by Reynold Xin

[SPARK-17698][SQL] Join predicates should not contain filter clauses

## What changes were proposed in this pull request?

Jira : https://issues.apache.org/jira/browse/SPARK-17698

`ExtractEquiJoinKeys` incorrectly uses filter predicates as join conditions. `canEvaluate` [0] checks whether an `Expression` can be evaluated using the output of a given `Plan`. For a filter predicate (e.g. `a.id='1'`), the `Expression` on the right-hand side (i.e. `'1'`) is a `Literal` with no attribute references. `expr.references` is therefore an empty set, which is trivially a subset of any set. This makes `canEvaluate` return `true`, so `a.id='1'` is treated as a join predicate. While this does not produce incorrect results, for bucketed and sorted tables it can prevent Spark from avoiding an unnecessary shuffle and sort. See the example below.

[0] : https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala#L91
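The root cause can be illustrated with a minimal standalone sketch (hypothetical model types, not Spark's actual `Expression`/`LogicalPlan` classes): a literal has no attribute references, and the empty set is a subset of every output set, so the subset check accepts it against either side of the join.

```
// Minimal standalone sketch of the canEvaluate bug (hypothetical model
// types; not Spark's actual classes).
object CanEvaluateSketch {
  case class Attribute(name: String)

  sealed trait Expression { def references: Set[Attribute] }
  case class AttrRef(attr: Attribute) extends Expression {
    def references: Set[Attribute] = Set(attr)
  }
  case class Lit(value: Any) extends Expression {
    // A literal references no attributes at all.
    def references: Set[Attribute] = Set.empty
  }

  case class Plan(outputSet: Set[Attribute])

  // Mirrors PredicateHelper.canEvaluate: empty references are a subset of
  // ANY output set, so literals are accepted against every plan.
  def canEvaluate(expr: Expression, plan: Plan): Boolean =
    expr.references.subsetOf(plan.outputSet)

  def main(args: Array[String]): Unit = {
    val aId = Attribute("a.id")
    val r = Plan(Set(aId))
    println(canEvaluate(AttrRef(aId), r))               // true
    println(canEvaluate(AttrRef(Attribute("b.id")), r)) // false
    println(canEvaluate(Lit("1"), r))                   // true -- the bug trigger
  }
}
```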


```
val df = (1 until 10).toDF("id").coalesce(1)
hc.sql("DROP TABLE IF EXISTS table1").collect
df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table1")
hc.sql("DROP TABLE IF EXISTS table2").collect
df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table2")

sqlContext.sql("""
  SELECT a.id, b.id
  FROM table1 a
  FULL OUTER JOIN table2 b
  ON a.id = b.id AND a.id='1' AND b.id='1'
""").explain(true)
```

BEFORE: The plan shuffles and sorts the table-scan outputs, which is unnecessary: both tables are bucketed and sorted on the same columns and have the same number of buckets. This should be a single-stage job.

```
SortMergeJoin [id#38, cast(id#38 as double), 1.0], [id#39, 1.0, cast(id#39 as double)], FullOuter
:- *Sort [id#38 ASC NULLS FIRST, cast(id#38 as double) ASC NULLS FIRST, 1.0 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(id#38, cast(id#38 as double), 1.0, 200)
:     +- *FileScan parquet default.table1[id#38] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
+- *Sort [id#39 ASC NULLS FIRST, 1.0 ASC NULLS FIRST, cast(id#39 as double) ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#39, 1.0, cast(id#39 as double), 200)
      +- *FileScan parquet default.table2[id#39] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
```

AFTER: No shuffle or sort is needed; the literal predicates stay in the join condition instead of being extracted as join keys.

```
SortMergeJoin [id#32], [id#33], FullOuter, ((cast(id#32 as double) = 1.0) && (cast(id#33 as double) = 1.0))
:- *FileScan parquet default.table1[id#32] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
+- *FileScan parquet default.table2[id#33] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
```

## How was this patch tested?

- Added a new test case for this scenario: `SPARK-17698 Join predicates should not contain filter clauses`
- Ran all the tests in `BucketedReadSuite`

Author: Tejas Patil <tejasp@fb.com>

Closes #15272 from tejasapatil/SPARK-17698_join_predicate_filter_clause.
parent e895bc25
```
@@ -84,8 +84,9 @@ trait PredicateHelper {
    *
    * For example consider a join between two relations R(a, b) and S(c, d).
    *
-   * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
-   * `false`.
+   * - `canEvaluate(EqualTo(a,b), R)` returns `true`
+   * - `canEvaluate(EqualTo(a,c), R)` returns `false`
+   * - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
```
```
@@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     val conditionalJoin = rest.find { planJoinPair =>
       val plan = planJoinPair._1
       val refs = left.outputSet ++ plan.outputSet
-      conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
+      conditions
+        .filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
+        .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
         .exists(_.references.subsetOf(refs))
     }
     // pick the next one if no condition left
```
```
@@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       // as join keys.
       val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
       val joinKeys = predicates.flatMap {
+        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
         case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
         case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
         // Replace null with default value for joining key, then those rows with null in it could
@@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
         case other => None
       }
       val otherPredicates = predicates.filterNot {
+        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
         case EqualTo(l, r) =>
           canEvaluate(l, left) && canEvaluate(r, right) ||
             canEvaluate(l, right) && canEvaluate(r, left)
```
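The effect of the `ExtractEquiJoinKeys` change can be sketched with a small standalone model (hypothetical types, not Spark's actual implementation): equalities with a literal on either side are no longer extracted as join keys and instead remain in the leftover (filter) predicates.

```
// Standalone sketch of the fixed predicate split (hypothetical model; the
// real logic lives in ExtractEquiJoinKeys). Attributes are plain strings.
object JoinKeySketch {
  sealed trait Expr { def references: Set[String] }
  case class Attr(name: String) extends Expr {
    def references: Set[String] = Set(name)
  }
  case class Lit(value: String) extends Expr {
    def references: Set[String] = Set.empty
  }
  case class EqualTo(l: Expr, r: Expr) extends Expr {
    def references: Set[String] = l.references ++ r.references
  }

  def canEvaluate(e: Expr, output: Set[String]): Boolean =
    e.references.subsetOf(output)

  // Split equality conjuncts into equi-join keys and leftover filter predicates.
  def split(predicates: Seq[EqualTo], left: Set[String], right: Set[String])
      : (Seq[(Expr, Expr)], Seq[EqualTo]) = {
    val joinKeys = predicates.flatMap {
      // The fix: a comparison with no references on one side is not a key.
      case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
      case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
      case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
      case _ => None
    }
    val otherPredicates = predicates.filterNot {
      case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
      case EqualTo(l, r) =>
        (canEvaluate(l, left) && canEvaluate(r, right)) ||
          (canEvaluate(l, right) && canEvaluate(r, left))
    }
    (joinKeys, otherPredicates)
  }

  def main(args: Array[String]): Unit = {
    // ON a.id = b.id AND a.id = '1' AND b.id = '1'
    val preds = Seq(
      EqualTo(Attr("a.id"), Attr("b.id")),
      EqualTo(Attr("a.id"), Lit("1")),
      EqualTo(Attr("b.id"), Lit("1")))
    val (keys, filters) = split(preds, left = Set("a.id"), right = Set("b.id"))
    println(keys)    // only (a.id, b.id) survives as a join key
    println(filters) // a.id = '1' and b.id = '1' stay as filter predicates
  }
}
```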
```
@@ -235,7 +235,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   private def testBucketing(
       bucketSpecLeft: Option[BucketSpec],
       bucketSpecRight: Option[BucketSpec],
-      joinColumns: Seq[String],
+      joinType: String = "inner",
+      joinCondition: (DataFrame, DataFrame) => Column,
       shuffleLeft: Boolean,
       shuffleRight: Boolean,
       sortLeft: Boolean = true,
@@ -268,12 +269,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
         SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       val t1 = spark.table("bucketed_table1")
       val t2 = spark.table("bucketed_table2")
-      val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+      val joined = t1.join(t2, joinCondition(t1, t2), joinType)

       // First check the result is corrected.
       checkAnswer(
         joined.sort("bucketed_table1.k", "bucketed_table2.k"),
-        df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+        df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))

       assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
       val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
```
```
@@ -297,56 +298,102 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }

-  private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+  private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
     joinCols.map(col => left(col) === right(col)).reduce(_ && _)
   }

   test("avoid shuffle when join 2 bucketed tables") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }

   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }

   test("only shuffle one side when join bucketed table and non-bucketed table") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = None,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("only shuffle one side when 2 bucketed tables have different bucket number") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
     val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
     val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("shuffle when join keys are not equal to bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("j")),
+      shuffleLeft = true,
+      shuffleRight = true
+    )
   }

   test("shuffle when join 2 bucketed tables with bucketing disabled") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+      testBucketing(
+        bucketSpecLeft = bucketSpec,
+        bucketSpecRight = bucketSpec,
+        joinCondition = joinCondition(Seq("i", "j")),
+        shuffleLeft = true,
+        shuffleRight = true
+      )
     }
   }

   test("avoid shuffle and sort when bucket and sort columns are join keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     testBucketing(
-      bucketSpec, bucketSpec, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = false
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
     )
   }
```
```
@@ -354,9 +401,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = false
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
     )
   }
@@ -364,9 +415,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = true
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = true
     )
   }
@@ -374,9 +429,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
     val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
     testBucketing(
-      bucketSpec1, bucketSpec2, Seq("i", "j"),
-      shuffleLeft = false, shuffleRight = false,
-      sortLeft = false, sortRight = true
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = true
     )
   }
```
```
@@ -408,6 +467,25 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }

+  test("SPARK-17698 Join predicates should not contain filter clauses") {
+    val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinType = "fullouter",
+      joinCondition = (left: DataFrame, right: DataFrame) => {
+        val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
+        val filterLeft = left("i") === Literal("1")
+        val filterRight = right("i") === Literal("1")
+        joinPredicates && filterLeft && filterRight
+      },
+      shuffleLeft = false,
+      shuffleRight = false,
+      sortLeft = false,
+      sortRight = false
+    )
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
```