From 77e845ca7726ffee2d6f8e33ea56ec005dde3874 Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Fri, 14 Nov 2014 12:00:08 -0800
Subject: [PATCH] [SPARK-4394][SQL] Data Sources API Improvements

This PR adds two features to the data sources API:
 - Support for pushing down `IN` filters
 - The ability for relations to optionally report their estimated size via `sizeInBytes` (both features are sketched below).
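
A minimal sketch of a source using both hooks, modeled on `SimpleFilteredScan`
in `FilteredScanSuite` (the relation name, the single-column schema, and the
4-bytes-per-row size estimate are illustrative, not part of this patch):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql._
    import org.apache.spark.sql.sources._

    case class SmallIntScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
      extends PrunedFilteredScan {

      override def schema =
        StructType(StructField("a", IntegerType, nullable = false) :: Nil)

      // Report an estimate so the planner may consider broadcasting this
      // relation; the default conservatively assumes tables are too large.
      override def sizeInBytes = (to - from + 1) * 4L

      override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
        // `a IN (1, 3, 5)`-style predicates over literals now arrive pushed
        // down as In("a", values) instead of only being applied after the scan.
        val predicates = filters.collect {
          case In("a", values) =>
            val set = values.map(_.asInstanceOf[Int]).toSet
            (a: Int) => set.contains(a)
        }
        sqlContext.sparkContext
          .parallelize(from to to)
          .filter(a => predicates.forall(_(a)))
          .map(Row(_))
      }
    }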

Author: Michael Armbrust <michael@databricks.com>

Closes #3260 from marmbrus/sourcesImprovements and squashes the following commits:

9a5e171 [Michael Armbrust] Use method instead of configuration directly
99c0e6b [Michael Armbrust] Add support for sizeInBytes.
416f167 [Michael Armbrust] Support for IN in data sources API.
2a04ab3 [Michael Armbrust] Simplify implementation of InSet.
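
The `InSet` simplification (2a04ab3) drops the redundant `child` constructor
argument; the predicate's single child is now derived from `value` itself.
Before/after shape, abridged from the diff below:

    // Before: callers had to redundantly pass `v +: list` as `child`.
    case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression])
      extends Predicate {
      def children = child
    }

    // After: `children` is derived, and any `Set` implementation is accepted.
    case class InSet(value: Expression, hset: Set[Any]) extends Predicate {
      def children = value :: Nil
    }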
---
 .../sql/catalyst/expressions/predicates.scala      |  4 ++--
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  2 +-
 .../expressions/ExpressionEvaluationSuite.scala    | 14 +++++++-------
 .../sql/catalyst/optimizer/OptimizeInSuite.scala   |  3 +--
 .../spark/sql/sources/DataSourceStrategy.scala     |  2 ++
 .../apache/spark/sql/sources/LogicalRelation.scala |  3 +--
 .../org/apache/spark/sql/sources/filters.scala     |  1 +
 .../org/apache/spark/sql/sources/interfaces.scala  | 11 ++++++++++-
 .../spark/sql/sources/FilteredScanSuite.scala      |  7 +++++++
 9 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 1e22b2d03c..94b6fb084d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
  * Optimized version of In clause, when all filter values of In clause are
  * static.
  */
-case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) 
+case class InSet(value: Expression, hset: Set[Any])
   extends Predicate {
 
-  def children = child
+  def children = value :: Nil
 
   def nullable = true // TODO: Figure out correct nullability semantics of IN.
   override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a4aa322fc5..f164a6c68a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
     case q: LogicalPlan => q transformExpressionsDown {
       case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
           val hSet = list.map(e => e.eval(null))
-          InSet(v, HashSet() ++ hSet, v +: list)
+          InSet(v, HashSet() ++ hSet)
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 918996f11d..2f57be94a8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite {
     val nl = Literal(null)
     val s = Seq(one, two)
     val nullS = Seq(one, two, null)
-    checkEvaluation(InSet(one, hS, one +: s), true)
-    checkEvaluation(InSet(two, hS, two +: s), true)
-    checkEvaluation(InSet(two, nS, two +: nullS), true)
-    checkEvaluation(InSet(nl, nS, nl +: nullS), true)
-    checkEvaluation(InSet(three, hS, three +: s), false)
-    checkEvaluation(InSet(three, nS, three +: nullS), false)
-    checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
+    checkEvaluation(InSet(one, hS), true)
+    checkEvaluation(InSet(two, hS), true)
+    checkEvaluation(InSet(two, nS), true)
+    checkEvaluation(InSet(nl, nS), true)
+    checkEvaluation(InSet(three, hS), false)
+    checkEvaluation(InSet(three, nS), false)
+    checkEvaluation(InSet(one, hS) && InSet(two, hS), true)
   }
 
   test("MaxOf") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index 97a78ec971..017b180c57 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest {
     val optimized = Optimize(originalQuery.analyze)
     val correctAnswer =
       testRelation
-        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, 
-            UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
+        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
         .analyze
 
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 9b8c6a56b9..954e86822d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy {
 
     case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
     case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
+
+    case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
index 82a2cf8402..4d87f6817d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/LogicalRelation.scala
@@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation)
   }
 
   @transient override lazy val statistics = Statistics(
-    // TODO: Allow datasources to provide statistics as well.
-    sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
+    sizeInBytes = BigInt(relation.sizeInBytes)
   )
 
   /** Used to lookup original attribute capitalization */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index e72a2aeb8f..4a9fefc12b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter
 case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
 case class LessThan(attribute: String, value: Any) extends Filter
 case class LessThanOrEqual(attribute: String, value: Any) extends Filter
+case class In(attribute: String, values: Array[Any]) extends Filter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index ac3bf9d8e1..861638b1e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext, StructType}
+import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
 import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
 
 /**
@@ -53,6 +53,15 @@ trait RelationProvider {
 abstract class BaseRelation {
   def sqlContext: SQLContext
   def schema: StructType
+
+  /**
+   * Returns an estimated size of this relation in bytes.  This information is used by the planner
+   * to decide when it is safe to broadcast a relation and can be overridden by sources that
+   * know the size ahead of time. By default, the system will assume that tables are too
+   * large to broadcast.  This method will be called multiple times during query planning
+   * and thus should not perform expensive operations for each invocation.
+   */
+  def sizeInBytes = sqlContext.defaultSizeInBytes
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 8b2f1591d5..939b3c0c66 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
       case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
       case GreaterThan("a", v: Int) => (a: Int) => a > v
       case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
+      case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
     }
 
     def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
@@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest {
     "SELECT * FROM oneToTenFiltered WHERE a = 1",
     Seq(1).map(i => Row(i, i * 2)).toSeq)
 
+  sqlTest(
+    "SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
+    Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
+
   sqlTest(
     "SELECT * FROM oneToTenFiltered WHERE A = 1",
     Seq(1).map(i => Row(i, i * 2)).toSeq)
@@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest {
 
   testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
 
+  testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+
   testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
   testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
 
-- 
GitLab