diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40db8bb6616a9e6617444ff5e3e9af681f0f..9eea2b0382535a4693d71e0c989c09cad41b34bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.{TimestampType, DateType, StringType, StructType}
 import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _}
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -343,11 +343,17 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
    * and convert them.
    */
   protected[sql] def selectFilters(filters: Seq[Expression]) = {
+    import CatalystTypeConverters._
+
     def translate(predicate: Expression): Option[Filter] = predicate match {
       case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
         Some(sources.EqualTo(a.name, v))
       case expressions.EqualTo(Literal(v, _), a: Attribute) =>
         Some(sources.EqualTo(a.name, v))
+      case expressions.EqualTo(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.EqualTo(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.EqualTo(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
         Some(sources.EqualNullSafe(a.name, v))
@@ -358,21 +364,41 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         Some(sources.GreaterThan(a.name, v))
       case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
         Some(sources.LessThan(a.name, v))
+      case expressions.GreaterThan(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.GreaterThan(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.LessThan(a: Attribute, Literal(v, _)) =>
         Some(sources.LessThan(a.name, v))
       case expressions.LessThan(Literal(v, _), a: Attribute) =>
         Some(sources.GreaterThan(a.name, v))
+      case expressions.LessThan(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.LessThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.LessThan(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.GreaterThan(a.name, convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
         Some(sources.GreaterThanOrEqual(a.name, v))
       case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
         Some(sources.LessThanOrEqual(a.name, v))
+      case expressions.GreaterThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.GreaterThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.GreaterThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.LessThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
         Some(sources.LessThanOrEqual(a.name, v))
       case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
         Some(sources.GreaterThanOrEqual(a.name, v))
+      case expressions.LessThanOrEqual(Cast(a: Attribute, _), l: Literal) =>
+        Some(sources.LessThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
+      case expressions.LessThanOrEqual(l: Literal, Cast(a: Attribute, _)) =>
+        Some(sources.GreaterThanOrEqual(a.name,
+          convertToScala(Cast(l, a.dataType).eval(), a.dataType)))
 
       case expressions.InSet(a: Attribute, set) =>
         Some(sources.In(a.name, set.toArray))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 8eab6a0adccc474e6426756d289998de33e00a6f..281943e23fcff0bb5a8cb9c4ed20303995f5a7d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -284,7 +284,7 @@ private[sql] class JDBCRDD(
   /**
    * `filters`, but as a WHERE clause suitable for injection into a SQL query.
    */
-  private val filterWhereClause: String = {
+  val filterWhereClause: String = {
     val filterStrings = filters map compileFilter filter (_ != null)
     if (filterStrings.size > 0) {
       val sb = new StringBuilder("WHERE ")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 42f2449afb0f940a9d2ccbd28ed7c3f79d8cc43e..b9cfae51e809c3fa801ffd2491edb4a17a6142fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,6 +25,8 @@ import org.h2.jdbc.JdbcSQLException
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
+import org.apache.spark.sql.execution.PhysicalRDD
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -148,6 +150,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
         |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
 
+    conn.prepareStatement("create table test.decimals (a DECIMAL(7, 2), b DECIMAL(4, 0))").
+      executeUpdate()
+    conn.prepareStatement("insert into test.decimals values (12345.67, 1234)").executeUpdate()
+    conn.prepareStatement("insert into test.decimals values (34567.89, 1428)").executeUpdate()
+    conn.commit()
+    sql(
+      s"""
+        |CREATE TEMPORARY TABLE decimals
+        |USING org.apache.spark.sql.jdbc
+        |OPTIONS (url '$url', dbtable 'TEST.DECIMALS', user 'testUser', password 'testPass')
+      """.stripMargin.replaceAll("\n", " "))
+
     conn.prepareStatement(
       s"""
         |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20),
@@ -445,4 +459,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
     assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
   }
 
+  test("SPARK-9182: filters are not passed through to jdbc source") {
+    def checkPushedFilter(query: String, filterStr: String): Unit = {
+      val rddOpt = sql(query).queryExecution.executedPlan.collectFirst {
+        case PhysicalRDD(_, rdd: JDBCRDD, _) => rdd
+      }
+      assert(rddOpt.isDefined)
+      val pushedFilterStr = rddOpt.get.filterWhereClause
+      assert(pushedFilterStr.contains(filterStr),
+        s"Expected to push [$filterStr], actually we pushed [$pushedFilterStr]")
+    }
+
+    checkPushedFilter("select * from foobar where NAME = 'fred'", "NAME = 'fred'")
+    checkPushedFilter("select * from inttypes where A > '15'", "A > 15")
+    checkPushedFilter("select * from inttypes where C <= 20", "C <= 20")
+    checkPushedFilter("select * from decimals where A > 1000", "A > 1000.00")
+    checkPushedFilter("select * from decimals where A > 1000 AND A < 2000",
+      "A > 1000.00 AND A < 2000.00")
+    checkPushedFilter("select * from decimals where A = 2000 AND B > 20", "A = 2000.00 AND B > 20")
+    checkPushedFilter("select * from timetypes where B > '1998-09-10'", "B > 1998-09-10")
+  }
 }
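
For illustration only (not part of the patch): a minimal spark-shell style sketch of what the new translate cases enable. The H2 URL and the TEST.INTTYPES table name below are assumptions borrowed from the test fixtures above; any JDBC table with an integer column A would behave the same. Before this change, the implicit Cast that Catalyst wraps around the column when comparing it with a string literal had no matching case in selectFilters, so no source Filter was produced and the comparison ran in Spark over unfiltered rows; with the change, the literal is cast back to the column's type and the JDBCRDD's WHERE clause carries "A > 15", as the new test asserts.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Minimal local setup; connection details are hypothetical.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("jdbc-pushdown-demo"))
val sqlContext = new SQLContext(sc)

val df = sqlContext.read.format("jdbc").options(Map(
  "url" -> "jdbc:h2:mem:testdb0;user=testUser;password=testPass",  // hypothetical H2 database
  "dbtable" -> "TEST.INTTYPES")).load()

// Catalyst wraps the INT column A in an implicit Cast to compare it with the string '15'.
// The new translate cases in DataSourceStrategy cast the literal back to A's type instead,
// so the physical plan's JDBCRDD is built with a pushed filter equivalent to "A > 15",
// visible in the extended plan output (or via filterWhereClause in the test suite).
df.filter("A > '15'").explain(true)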