Skip to content
Snippets Groups Projects
Commit 77e845ca authored by Michael Armbrust's avatar Michael Armbrust Committed by Reynold Xin
Browse files

[SPARK-4394][SQL] Data Sources API Improvements

This PR adds two features to the data sources API:
 - Support for pushing down `IN` filters
 - The ability for relations to optionally provide information about their `sizeInBytes`.

Author: Michael Armbrust <michael@databricks.com>

Closes #3260 from marmbrus/sourcesImprovements and squashes the following commits:

9a5e171 [Michael Armbrust] Use method instead of configuration directly
99c0e6b [Michael Armbrust] Add support for sizeInBytes.
416f167 [Michael Armbrust] Support for IN in data sources API.
2a04ab3 [Michael Armbrust] Simplify implementation of InSet.
parent e421072d
No related branches found
No related tags found
No related merge requests found
Showing
with 32 additions and 15 deletions
...@@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ...@@ -99,10 +99,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
* Optimized version of In clause, when all filter values of In clause are * Optimized version of In clause, when all filter values of In clause are
* static. * static.
*/ */
case class InSet(value: Expression, hset: HashSet[Any], child: Seq[Expression]) case class InSet(value: Expression, hset: Set[Any])
extends Predicate { extends Predicate {
def children = child def children = value :: Nil
def nullable = true // TODO: Figure out correct nullability semantics of IN. def nullable = true // TODO: Figure out correct nullability semantics of IN.
override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}" override def toString = s"$value INSET ${hset.mkString("(", ",", ")")}"
......
...@@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] { ...@@ -289,7 +289,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
case q: LogicalPlan => q transformExpressionsDown { case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) => case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(null)) val hSet = list.map(e => e.eval(null))
InSet(v, HashSet() ++ hSet, v +: list) InSet(v, HashSet() ++ hSet)
} }
} }
} }
......
...@@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite { ...@@ -158,13 +158,13 @@ class ExpressionEvaluationSuite extends FunSuite {
val nl = Literal(null) val nl = Literal(null)
val s = Seq(one, two) val s = Seq(one, two)
val nullS = Seq(one, two, null) val nullS = Seq(one, two, null)
checkEvaluation(InSet(one, hS, one +: s), true) checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS, two +: s), true) checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS, two +: nullS), true) checkEvaluation(InSet(two, nS), true)
checkEvaluation(InSet(nl, nS, nl +: nullS), true) checkEvaluation(InSet(nl, nS), true)
checkEvaluation(InSet(three, hS, three +: s), false) checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(three, nS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) checkEvaluation(InSet(one, hS) && InSet(two, hS), true)
} }
test("MaxOf") { test("MaxOf") {
......
...@@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest { ...@@ -52,8 +52,7 @@ class OptimizeInSuite extends PlanTest {
val optimized = Optimize(originalQuery.analyze) val optimized = Optimize(originalQuery.analyze)
val correctAnswer = val correctAnswer =
testRelation testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2, .where(InSet(UnresolvedAttribute("a"), HashSet[Any]()+1+2))
UnresolvedAttribute("a") +: Seq(Literal(1),Literal(2))))
.analyze .analyze
comparePlans(optimized, correctAnswer) comparePlans(optimized, correctAnswer)
......
...@@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy { ...@@ -108,5 +108,7 @@ private[sql] object DataSourceStrategy extends Strategy {
case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v)
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v)
case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
} }
} }
...@@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation) ...@@ -41,8 +41,7 @@ private[sql] case class LogicalRelation(relation: BaseRelation)
} }
@transient override lazy val statistics = Statistics( @transient override lazy val statistics = Statistics(
// TODO: Allow datasources to provide statistics as well. sizeInBytes = BigInt(relation.sizeInBytes)
sizeInBytes = BigInt(relation.sqlContext.defaultSizeInBytes)
) )
/** Used to lookup original attribute capitalization */ /** Used to lookup original attribute capitalization */
......
...@@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter ...@@ -24,3 +24,4 @@ case class GreaterThan(attribute: String, value: Any) extends Filter
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
case class LessThan(attribute: String, value: Any) extends Filter case class LessThan(attribute: String, value: Any) extends Filter
case class LessThanOrEqual(attribute: String, value: Any) extends Filter case class LessThanOrEqual(attribute: String, value: Any) extends Filter
case class In(attribute: String, values: Array[Any]) extends Filter
...@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources ...@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.annotation.DeveloperApi import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, StructType} import org.apache.spark.sql.{SQLConf, Row, SQLContext, StructType}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute} import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
/** /**
...@@ -53,6 +53,15 @@ trait RelationProvider { ...@@ -53,6 +53,15 @@ trait RelationProvider {
abstract class BaseRelation { abstract class BaseRelation {
def sqlContext: SQLContext def sqlContext: SQLContext
def schema: StructType def schema: StructType
/**
* Returns an estimated size of this relation in bytes. This information is used by the planner
* to decide when it is safe to broadcast a relation and can be overridden by sources that
* know the size ahead of time. By default, the system will assume that tables are too
* large to broadcast. This method will be called multiple times during query planning
* and thus should not perform expensive operations for each invocation.
*/
def sizeInBytes = sqlContext.defaultSizeInBytes
} }
/** /**
......
...@@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL ...@@ -51,6 +51,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
case GreaterThan("a", v: Int) => (a: Int) => a > v case GreaterThan("a", v: Int) => (a: Int) => a > v
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
} }
def eval(a: Int) = !filterFunctions.map(_(a)).contains(false) def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
...@@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest { ...@@ -121,6 +122,10 @@ class FilteredScanSuite extends DataSourceTest {
"SELECT * FROM oneToTenFiltered WHERE a = 1", "SELECT * FROM oneToTenFiltered WHERE a = 1",
Seq(1).map(i => Row(i, i * 2)).toSeq) Seq(1).map(i => Row(i, i * 2)).toSeq)
sqlTest(
"SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)",
Seq(1,3,5).map(i => Row(i, i * 2)).toSeq)
sqlTest( sqlTest(
"SELECT * FROM oneToTenFiltered WHERE A = 1", "SELECT * FROM oneToTenFiltered WHERE A = 1",
Seq(1).map(i => Row(i, i * 2)).toSeq) Seq(1).map(i => Row(i, i * 2)).toSeq)
...@@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest { ...@@ -150,6 +155,8 @@ class FilteredScanSuite extends DataSourceTest {
testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8) testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0) testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10) testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment