From 5d871ea43efdde59e05896a50d57021388412d30 Mon Sep 17 00:00:00 2001
From: QiangCai <david.caiq@gmail.com>
Date: Wed, 6 Jan 2016 18:13:07 +0900
Subject: [PATCH] [SPARK-12340][SQL] fix Int overflow in the SparkPlan.executeTake, RDD.take and AsyncRDDActions.takeAsync

I have closed pull request https://github.com/apache/spark/pull/10487 and
opened this pull request to resolve the problem instead.

Spark JIRA: https://issues.apache.org/jira/browse/SPARK-12340

Author: QiangCai <david.caiq@gmail.com>

Closes #10562 from QiangCai/bugfix.
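The standalone sketch below is not part of the patch: the object name is made
up and the growth formula is simplified. It only illustrates the bug. When the
requested number of rows is close to Int.MaxValue, the Int sum of partsScanned
and numPartsToTry can wrap to a negative value, so the partition range to scan
comes out empty and the take loop can stop making progress. Tracking the
counters as Long and capping at totalParts before converting back to Int avoids
the wrap.

// Standalone sketch (hypothetical object name, simplified growth formula); not Spark source code.
object Spark12340Sketch {
  def main(args: Array[String]): Unit = {
    val totalParts = 3      // partitions in the RDD
    val num = 2147483638    // a limit close to Int.MaxValue, as in the new test

    // Before the fix: Int arithmetic can wrap around.
    val partsScannedInt = 1
    val numPartsToTryInt = (1.5 * num * partsScannedInt).toInt  // Double.toInt saturates at Int.MaxValue
    val upperInt = partsScannedInt + numPartsToTryInt           // wraps to -2147483648
    val pOld = partsScannedInt until math.min(upperInt, totalParts)
    println(pOld.isEmpty)                                       // true: no partitions would be scanned

    // After the fix: Long arithmetic, capped at totalParts before converting back to Int.
    val partsScanned = 1L
    val numPartsToTry = (1.5 * num * partsScanned).toLong
    val pNew = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
    println(pNew.isEmpty)                                       // false: partitions 1 and 2 get scanned
  }
}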
---
 .../scala/org/apache/spark/rdd/AsyncRDDActions.scala | 12 ++++++------
 core/src/main/scala/org/apache/spark/rdd/RDD.scala   |  8 ++++----
 .../org/apache/spark/sql/execution/SparkPlan.scala   |  8 ++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala   | 12 ++++++++++++
 4 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index ec48925823..94719a4572 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
     val localProperties = self.context.getLocalProperties
     // Cached thread pool to handle aggregation of subtasks.
     implicit val executionContext = AsyncRDDActions.futureExecutionContext
-    val results = new ArrayBuffer[T](num)
+    val results = new ArrayBuffer[T]
     val totalParts = self.partitions.length

     /*
@@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
       This implementation is non-blocking, asynchronously handling the results of each job and
       triggering the next job using callbacks on futures.
      */
-    def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
+    def continue(partsScanned: Long)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] =
       if (results.size >= num || partsScanned >= totalParts) {
         Future.successful(results.toSeq)
       } else {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
         }

         val left = num - results.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt

         val buf = new Array[Array[T]](p.size)
         self.context.setCallSite(callSite)
@@ -111,11 +111,11 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
             Unit)
         job.flatMap {_ =>
           buf.foreach(results ++= _.take(num - results.size))
-          continue(partsScanned + numPartsToTry)
+          continue(partsScanned + p.size)
         }
       }

-    new ComplexFutureAction[Seq[T]](continue(0)(_))
+    new ComplexFutureAction[Seq[T]](continue(0L)(_))
   }

   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index d6eac7888d..e25657cc10 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,11 +1190,11 @@ abstract class RDD[T: ClassTag](
     } else {
       val buf = new ArrayBuffer[T]
       val totalParts = this.partitions.length
-      var partsScanned = 0
+      var partsScanned = 0L
       while (buf.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
+        var numPartsToTry = 1L
         if (partsScanned > 0) {
           // If we didn't find any rows after the previous iteration, quadruple and retry.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate
@@ -1209,11 +1209,11 @@ abstract class RDD[T: ClassTag](
         }

         val left = num - buf.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+        val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
         val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

         res.foreach(buf ++= _.take(num - buf.size))
-        partsScanned += numPartsToTry
+        partsScanned += p.size
       }

       buf.toArray
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index f20f32aace..21a6fba907 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -165,11 +165,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ

     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
-    var partsScanned = 0
+    var partsScanned = 0L
     while (buf.size < n && partsScanned < totalParts) {
       // The number of partitions to try in this iteration. It is ok for this number to be
       // greater than totalParts because we actually cap it at totalParts in runJob.
-      var numPartsToTry = 1
+      var numPartsToTry = 1L
       if (partsScanned > 0) {
         // If we didn't find any rows after the first iteration, just try all partitions next.
         // Otherwise, interpolate the number of partitions we need to try, but overestimate it
@@ -183,13 +183,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions

       val left = n - buf.size
-      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val p = partsScanned.toInt until math.min(partsScanned + numPartsToTry, totalParts).toInt
       val sc = sqlContext.sparkContext
       val res = sc.runJob(childRDD,
         (it: Iterator[InternalRow]) => it.take(left).toArray, p)

       res.foreach(buf ++= _.take(n - buf.size))
-      partsScanned += numPartsToTry
+      partsScanned += p.size
     }

     buf.toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5de0979606..bd987ae1bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2067,4 +2067,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       )
     }
   }
+
+  test("SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake") {
+    val rdd = sqlContext.sparkContext.parallelize(1 to 3, 3)
+    rdd.toDF("key").registerTempTable("spark12340")
+    checkAnswer(
+      sql("select key from spark12340 limit 2147483638"),
+      Row(1) :: Row(2) :: Row(3) :: Nil
+    )
+    assert(rdd.take(2147483638).size === 3)
+    assert(rdd.takeAsync(2147483638).get.size === 3)
+  }
+
 }
--
GitLab
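A quick way to exercise the fix outside the test suite: the snippet below is a
hypothetical spark-shell session that mirrors the values used in the new test.
It is not part of the patch and assumes the sc and sqlContext instances that
the 1.6-era spark-shell provides.

// Hypothetical spark-shell session; before this patch, a limit close to
// Int.MaxValue could make the partition range computation wrap around.
import sqlContext.implicits._

val rdd = sc.parallelize(1 to 3, 3)
rdd.toDF("key").registerTempTable("spark12340")
sqlContext.sql("select key from spark12340 limit 2147483638").collect()  // Array([1], [2], [3])
rdd.take(2147483638).length                                              // 3
rdd.takeAsync(2147483638).get.size                                       // 3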