diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7b30fcc6c531492f4cafed67354cf0ca444d468d..8b87a4e41c23d186eb3f7c5e74afde504a31034d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -46,6 +46,21 @@ trait CheckAnalysis extends PredicateHelper {
     }).length > 1
   }
 
+  /** Fails analysis unless `limitExpr` is a constant, non-null, non-negative integer. */
+  private def checkLimitClause(limitExpr: Expression): Unit = limitExpr match {
+    case e if !e.foldable => failAnalysis(
+      "The limit expression must evaluate to a constant value, but got " + limitExpr.sql)
+    case e if e.dataType != IntegerType => failAnalysis(
+      "The limit expression must be integer type, but got " + e.dataType.simpleString)
+    case e => e.eval() match { // evaluated once; a null result must not be unboxed to 0
+      case null => failAnalysis(
+        "The evaluated limit expression must not be null, but got " + limitExpr.sql)
+      case v: Int if v < 0 => failAnalysis(
+        "The limit expression must be equal to or greater than 0, but got " + v)
+      case _ => // OK
+    }
+  }
+
   def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
@@ -251,6 +266,10 @@ trait CheckAnalysis extends PredicateHelper {
                 s"but one table has '${firstError.output.length}' columns and another table has " +
                 s"'${s.children.head.output.length}' columns")
 
+          case GlobalLimit(limitExpr, _) => checkLimitClause(limitExpr)
+
+          case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr)
+
           case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
             p match {
               case _: Filter | _: Aggregate | _: Project => // Ok
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 79f9a210a30b5c407eecc0b731f9f542d66e124b..c0e400f61777ff477bb76727a3cba48e46dd9dcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -660,7 +660,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
   }
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    val sizeInBytes = if (limit == 0) {
+      // Never report a size of zero: a BinaryNode's sizeInBytes is the product of its
+      // children's sizes, so a zero here would make the whole BinaryNode estimate zero.
+      1
+    } else {
+      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    }
     child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
@@ -675,7 +681,13 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
   }
   override lazy val statistics: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    val sizeInBytes = if (limit == 0) {
+      // Never report a size of zero: a BinaryNode's sizeInBytes is the product of its
+      // children's sizes, so a zero here would make the whole BinaryNode estimate zero.
+      1
+    } else {
+      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+    }
     child.statistics.copy(sizeInBytes = sizeInBytes)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a9cde1e19efc812bad09e72874463dfc8aed9b85..ff112c51697ade4a69b1878133cfb7422d631ca2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -352,6 +352,12 @@ class AnalysisErrorSuite extends AnalysisTest {
     "Generators are not supported outside the SELECT clause, but got: Sort" :: Nil
   )
 
+  errorTest(
+    "num_rows in limit clause must be equal to or greater than 0",
+    listRelation.limit(-1),
+    "The limit expression must be equal to or greater than 0, but got -1" :: Nil
+  )
+
   errorTest(
     "more than one generators in SELECT",
     listRelation.select(Explode('list), Explode('list)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index dca9e5e503c72b1a5b694b44911a7997dcce12e3..ede7d9a0c95b98486cef7bc0b82fda3965f53c1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -660,11 +660,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("limit") {
     checkAnswer(
-      sql("SELECT * FROM testData LIMIT 10"),
+      sql("SELECT * FROM testData LIMIT 9 + 1"),
       testData.take(10).toSeq)
 
     checkAnswer(
-      sql("SELECT * FROM arrayData LIMIT 1"),
+      sql("SELECT * FROM arrayData LIMIT CAST(1 AS Integer)"),
       arrayData.collect().take(1).map(Row.fromTuple).toSeq)
 
     checkAnswer(
@@ -672,6 +672,39 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       mapData.collect().take(1).map(Row.fromTuple).toSeq)
   }
 
+  test("non-foldable expressions in LIMIT") {
+    val e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT key > 3")
+    }.getMessage
+    assert(e.contains("The limit expression must evaluate to a constant value, " +
+      "but got (testdata.`key` > 3)"))
+  }
+
+  test("Expressions in limit clause are not integer") {
+    var e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT true")
+    }.getMessage
+    assert(e.contains("The limit expression must be integer type, but got boolean"))
+
+    e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT 'a'")
+    }.getMessage
+    assert(e.contains("The limit expression must be integer type, but got string"))
+  }
+
+  test("negative in LIMIT or TABLESAMPLE") {
+    val expected = "The limit expression must be equal to or greater than 0, but got -1"
+    var e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData TABLESAMPLE (-1 rows)")
+    }.getMessage
+    assert(e.contains(expected))
+
+    e = intercept[AnalysisException] {
+      sql("SELECT * FROM testData LIMIT -1")
+    }.getMessage
+    assert(e.contains(expected))
+  }
+
   test("CTE feature") {
     checkAnswer(
       sql("with q1 as (select * from testData limit 10) select * from q1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
index 4de3cf605caa17bc01ad39112980f7f42b75c279..ab55242ec06833c2bc228ab547af3581ff9763d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, Join, LocalLimit}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
 class StatisticsSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
     val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
@@ -31,4 +33,46 @@ class StatisticsSuite extends QueryTest with SharedSQLContext {
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
 
+  test("estimates the size of limit") {
+    withTempTable("test") {
+      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+        .createOrReplaceTempView("test")
+      Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
+        val df = sql(s"""SELECT * FROM test limit $limit""")
+
+        val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit =>
+          g.statistics.sizeInBytes
+        }
+        assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+        assert(sizesGlobalLimit.head === BigInt(expected),
+          s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}")
+
+        val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit =>
+          l.statistics.sizeInBytes
+        }
+        assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}")
+        assert(sizesLocalLimit.head === BigInt(expected),
+          s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}")
+      }
+    }
+  }
+
+  test("estimates the size of a limit 0 on outer join") {
+    withTempTable("test") {
+      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
+        .createOrReplaceTempView("test")
+      val df1 = spark.table("test")
+      val df2 = spark.table("test").limit(0)
+      val df = df1.join(df2, Seq("k"), "left")
+
+      val sizes = df.queryExecution.analyzed.collect { case g: Join =>
+        g.statistics.sizeInBytes
+      }
+
+      assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
+      assert(sizes.head === BigInt(96),
+        s"expected exact size 96 for table 'test', got: ${sizes.head}")
+    }
+  }
+
 }