From b14bfc3f8e97479ac5927c071b00ed18f2104c95 Mon Sep 17 00:00:00 2001
From: Dilip Biswal <dbiswal@us.ibm.com>
Date: Wed, 12 Apr 2017 12:18:01 +0800
Subject: [PATCH] [SPARK-19993][SQL] Caching logical plans containing subquery
 expressions does not work.

## What changes were proposed in this pull request?
The sameResult() method does not work when the logical plan contains subquery expressions.

**Before the fix**
```SQL
scala> val ds = spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)")
ds: org.apache.spark.sql.DataFrame = [c1: int]

scala> ds.cache
res13: ds.type = [c1: int]

scala> spark.sql("select * from s1 where s1.c1 in (select s2.c1 from s2 where s1.c1 = s2.c1)").explain(true)
== Analyzed Logical Plan ==
c1: int
Project [c1#86]
+- Filter c1#86 IN (list#78 [c1#86])
   :  +- Project [c1#87]
   :     +- Filter (outer(c1#86) = c1#87)
   :        +- SubqueryAlias s2
   :           +- Relation[c1#87] parquet
   +- SubqueryAlias s1
      +- Relation[c1#86] parquet

== Optimized Logical Plan ==
Join LeftSemi, ((c1#86 = c1#87) && (c1#86 = c1#87))
:- Relation[c1#86] parquet
+- Relation[c1#87] parquet
```
**Plan after fix**
```SQL
== Analyzed Logical Plan ==
c1: int
Project [c1#22]
+- Filter c1#22 IN (list#14 [c1#22])
   :  +- Project [c1#23]
   :     +- Filter (outer(c1#22) = c1#23)
   :        +- SubqueryAlias s2
   :           +- Relation[c1#23] parquet
   +- SubqueryAlias s1
      +- Relation[c1#22] parquet

== Optimized Logical Plan ==
InMemoryRelation [c1#22], true, 10000, StorageLevel(disk, memory, deserialized, 1 replicas)
   +- *BroadcastHashJoin [c1#1, c1#1], [c1#2, c1#2], LeftSemi, BuildRight
      :- *FileScan parquet default.s1[c1#1] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<c1:int>
      +- BroadcastExchange HashedRelationBroadcastMode(List((shiftleft(cast(input[0, int, true] as bigint), 32) | (cast(input[0, int, true] as bigint) & 4294967295))))
         +- *FileScan parquet default.s2[c1#2] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/dbiswal/mygit/apache/spark/bin/spark-warehouse/s2], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<c1:int>
```
## How was this patch tested?
New tests are added to CachedTableSuite.

Author: Dilip Biswal <dbiswal@us.ibm.com>

Closes #17330 from dilipbiswal/subquery_cache_final.
---
 .../sql/catalyst/expressions/subquery.scala   |  26 +++-
 .../spark/sql/catalyst/plans/QueryPlan.scala  |  43 +++---
 .../sql/execution/DataSourceScanExec.scala    |   7 +-
 .../apache/spark/sql/CachedTableSuite.scala   | 143 +++++++++++++++++-
 .../hive/execution/HiveTableScanExec.scala    |   5 +-
 5 files changed, 198 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index 59db28d58a..d7b493d521 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -47,7 +47,6 @@ abstract class SubqueryExpression(
     plan: LogicalPlan,
     children: Seq[Expression],
     exprId: ExprId) extends PlanExpression[LogicalPlan] {
-
   override lazy val resolved: Boolean = childrenResolved && plan.resolved
   override lazy val references: AttributeSet =
     if (plan.resolved) super.references -- plan.outputSet else super.references
@@ -59,6 +58,13 @@ abstract class SubqueryExpression(
         children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
     case _ => false
   }
+  def canonicalize(attrs: AttributeSeq): SubqueryExpression = {
+    // Normalize the outer references in the subquery plan.
+    val normalizedPlan = plan.transformAllExpressions {
+      case OuterReference(r) => OuterReference(QueryPlan.normalizeExprId(r, attrs))
+    }
+    withNewPlan(normalizedPlan).canonicalized.asInstanceOf[SubqueryExpression]
+  }
 }
 
 object SubqueryExpression {
@@ -236,6 +242,12 @@ case class ScalarSubquery(
   override def nullable: Boolean = true
   override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
   override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    ScalarSubquery(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }
 
 object ScalarSubquery {
@@ -268,6 +280,12 @@ case class ListQuery(
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
   override def toString: String = s"list#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    ListQuery(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }
 
 /**
@@ -290,4 +308,10 @@ case class Exists(
   override def nullable: Boolean = false
   override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
   override def toString: String = s"exists#${exprId.id} $conditionString"
+  override lazy val canonicalized: Expression = {
+    Exists(
+      plan.canonicalized,
+      children.map(_.canonicalized),
+      ExprId(0))
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3008e8cb84..2fb65bd435 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -377,7 +377,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
         // As the root of the expression, Alias will always take an arbitrary exprId, we need to
         // normalize that for equality testing, by assigning expr id from 0 incrementally. The
         // alias name doesn't matter and should be erased.
-        Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
+        val normalizedChild = QueryPlan.normalizeExprId(a.child, allAttributes)
+        Alias(normalizedChild, "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
 
       case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
         // Top level `AttributeReference` may also be used for output like `Alias`, we should
@@ -385,7 +386,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
         id += 1
         ar.withExprId(ExprId(id))
 
-      case other => normalizeExprId(other)
+      case other => QueryPlan.normalizeExprId(other, allAttributes)
     }.withNewChildren(canonicalizedChildren)
   }
 
@@ -395,23 +396,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   protected def preCanonicalized: PlanType = this
 
-  /**
-   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
-   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
-   * do not use `BindReferences` here as the plan may take the expression as a parameter with type
-   * `Attribute`, and replace it with `BoundReference` will cause error.
-   */
-  protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
-    e.transformUp {
-      case ar: AttributeReference =>
-        val ordinal = input.indexOf(ar.exprId)
-        if (ordinal == -1) {
-          ar
-        } else {
-          ar.withExprId(ExprId(ordinal))
-        }
-    }.canonicalized.asInstanceOf[T]
-  }
 
   /**
    * Returns true when the given query plan will return the same results as this query plan.
@@ -438,3 +422,24 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    */
   lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 }
+
+object QueryPlan {
+  /**
+   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
+   * do not use `BindReferences` here as the plan may take the expression as a parameter with type
+   * `Attribute`, and replace it with `BoundReference` will cause error.
+   */
+  def normalizeExprId[T <: Expression](e: T, input: AttributeSeq): T = {
+    e.transformUp {
+      case s: SubqueryExpression => s.canonicalize(input)
+      case ar: AttributeReference =>
+        val ordinal = input.indexOf(ar.exprId)
+        if (ordinal == -1) {
+          ar
+        } else {
+          ar.withExprId(ExprId(ordinal))
+        }
+    }.canonicalized.asInstanceOf[T]
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 3a9132d74a..866fa98533 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
@@ -516,10 +517,10 @@ case class FileSourceScanExec(
   override lazy val canonicalized: FileSourceScanExec = {
     FileSourceScanExec(
       relation,
-      output.map(normalizeExprId(_, output)),
+      output.map(QueryPlan.normalizeExprId(_, output)),
       requiredSchema,
-      partitionFilters.map(normalizeExprId(_, output)),
-      dataFilters.map(normalizeExprId(_, output)),
+      partitionFilters.map(QueryPlan.normalizeExprId(_, output)),
+      dataFilters.map(QueryPlan.normalizeExprId(_, output)),
       None)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 7a7d52b214..e66fe97afa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.CleanerListener
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
-import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
 import org.apache.spark.sql.functions._
@@ -76,6 +76,13 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     sum
   }
 
+  private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
+    plan.collect {
+      case InMemoryTableScanExec(_, _, relation) =>
+        getNumInMemoryTablesRecursively(relation.child) + 1
+    }.sum
+  }
+
   test("withColumn doesn't invalidate cached dataframe") {
     var evalCount = 0
     val myUDF = udf((x: String) => { evalCount += 1; "result" })
@@ -670,4 +677,138 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
       assert(spark.read.parquet(path).filter($"id" > 4).count() == 15)
     }
   }
+
+  test("SPARK-19993 simple subquery caching") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t2)
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t2)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+
+      // Additional predicate in the subquery plan should cause a cache miss
+      val cachedMissDs =
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t2 where c1 = 0)
+        """.stripMargin)
+      assert(getNumInMemoryRelations(cachedMissDs) == 0)
+    }
+  }
+
+  test("SPARK-19993 subquery caching with correlated predicates") {
+    withTempView("t1", "t2") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(1).toDF("c1").createOrReplaceTempView("t2")
+
+      // Simple correlated predicate in subquery
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |t1.c1 in (SELECT t2.c1 FROM t2 where t1.c1 = t2.c1)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+    }
+  }
+
+  test("SPARK-19993 subquery with cached underlying relation") {
+    withTempView("t1") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      spark.catalog.cacheTable("t1")
+
+      // underlying table t1 is cached as well as the query that refers to it.
+      val ds =
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |NOT EXISTS (SELECT * FROM t1)
+        """.stripMargin)
+      assert(getNumInMemoryRelations(ds) == 2)
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |NOT EXISTS (SELECT * FROM t1)
+          """.stripMargin).cache()
+      assert(getNumInMemoryTablesRecursively(cachedDs.queryExecution.sparkPlan) == 3)
+    }
+  }
+
+  test("SPARK-19993 nested subquery caching and scalar + predicate subqueris") {
+    withTempView("t1", "t2", "t3", "t4") {
+      Seq(1).toDF("c1").createOrReplaceTempView("t1")
+      Seq(2).toDF("c1").createOrReplaceTempView("t2")
+      Seq(1).toDF("c1").createOrReplaceTempView("t3")
+      Seq(1).toDF("c1").createOrReplaceTempView("t4")
+
+      // Nested predicate subquery
+      sql(
+        """
+          |SELECT * FROM t1
+          |WHERE
+          |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+        """.stripMargin).cache()
+
+      val cachedDs =
+        sql(
+          """
+            |SELECT * FROM t1
+            |WHERE
+            |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1))
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs) == 1)
+
+      // Scalar subquery and predicate subquery
+      sql(
+        """
+          |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+          |WHERE
+          |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+          |OR
+          |EXISTS (SELECT c1 FROM t3)
+          |OR
+          |c1 IN (SELECT c1 FROM t4)
+        """.stripMargin).cache()
+
+      val cachedDs2 =
+        sql(
+          """
+            |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1)
+            |WHERE
+            |c1 = (SELECT max(c1) FROM t2 GROUP BY c1)
+            |OR
+            |EXISTS (SELECT c1 FROM t3)
+            |OR
+            |c1 IN (SELECT c1 FROM t4)
+          """.stripMargin)
+      assert(getNumInMemoryRelations(cachedDs2) == 1)
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index fab0d7fa84..666548d1a4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogRelation
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
@@ -203,9 +204,9 @@ case class HiveTableScanExec(
   override lazy val canonicalized: HiveTableScanExec = {
     val input: AttributeSeq = relation.output
     HiveTableScanExec(
-      requestedAttributes.map(normalizeExprId(_, input)),
+      requestedAttributes.map(QueryPlan.normalizeExprId(_, input)),
       relation.canonicalized.asInstanceOf[CatalogRelation],
-      partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession)
+      partitionPruningPred.map(QueryPlan.normalizeExprId(_, input)))(sparkSession)
   }
 
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
-- 
GitLab