Skip to content
Snippets Groups Projects
Commit a1132168 authored by uncleGen's avatar uncleGen Committed by Sean Owen
Browse files

[SPARK-12031][CORE][BUG] Integer overflow when sampling

Author: uncleGen <hustyugm@gmail.com>

Closes #10023 from uncleGen/1.6-bugfix.
parent f6883bb7
No related branches found
No related tags found
No related merge requests found
...@@ -253,7 +253,7 @@ private[spark] object RangePartitioner { ...@@ -253,7 +253,7 @@ private[spark] object RangePartitioner {
*/ */
def sketch[K : ClassTag]( def sketch[K : ClassTag](
rdd: RDD[K], rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
val shift = rdd.id val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
...@@ -262,7 +262,7 @@ private[spark] object RangePartitioner { ...@@ -262,7 +262,7 @@ private[spark] object RangePartitioner {
iter, sampleSizePerPartition, seed) iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample)) Iterator((idx, n, sample))
}.collect() }.collect()
val numItems = sketched.map(_._2.toLong).sum val numItems = sketched.map(_._2).sum
(numItems, sketched) (numItems, sketched)
} }
......
...@@ -34,7 +34,7 @@ private[spark] object SamplingUtils { ...@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
input: Iterator[T], input: Iterator[T],
k: Int, k: Int,
seed: Long = Random.nextLong()) seed: Long = Random.nextLong())
: (Array[T], Int) = { : (Array[T], Long) = {
val reservoir = new Array[T](k) val reservoir = new Array[T](k)
// Put the first k elements in the reservoir. // Put the first k elements in the reservoir.
var i = 0 var i = 0
...@@ -52,16 +52,17 @@ private[spark] object SamplingUtils { ...@@ -52,16 +52,17 @@ private[spark] object SamplingUtils {
(trimReservoir, i) (trimReservoir, i)
} else { } else {
// If input size > k, continue the sampling process. // If input size > k, continue the sampling process.
var l = i.toLong
val rand = new XORShiftRandom(seed) val rand = new XORShiftRandom(seed)
while (input.hasNext) { while (input.hasNext) {
val item = input.next() val item = input.next()
val replacementIndex = rand.nextInt(i) val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) { if (replacementIndex < k) {
reservoir(replacementIndex) = item reservoir(replacementIndex.toInt) = item
} }
i += 1 l += 1
} }
(reservoir, i) (reservoir, l)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment