Skip to content
Snippets Groups Projects
Commit a1132168 authored by uncleGen's avatar uncleGen Committed by Sean Owen
Browse files

[SPARK-12031][CORE][BUG] Integer overflow when do sampling

Author: uncleGen <hustyugm@gmail.com>

Closes #10023 from uncleGen/1.6-bugfix.
parent f6883bb7
No related branches found
No related tags found
No related merge requests found
......@@ -253,7 +253,7 @@ private[spark] object RangePartitioner {
*/
def sketch[K : ClassTag](
rdd: RDD[K],
sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = {
sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
val shift = rdd.id
// val classTagK = classTag[K] // to avoid serializing the entire partitioner object
val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
......@@ -262,7 +262,7 @@ private[spark] object RangePartitioner {
iter, sampleSizePerPartition, seed)
Iterator((idx, n, sample))
}.collect()
val numItems = sketched.map(_._2.toLong).sum
val numItems = sketched.map(_._2).sum
(numItems, sketched)
}
......
......@@ -34,7 +34,7 @@ private[spark] object SamplingUtils {
input: Iterator[T],
k: Int,
seed: Long = Random.nextLong())
: (Array[T], Int) = {
: (Array[T], Long) = {
val reservoir = new Array[T](k)
// Put the first k elements in the reservoir.
var i = 0
......@@ -52,16 +52,17 @@ private[spark] object SamplingUtils {
(trimReservoir, i)
} else {
// If input size > k, continue the sampling process.
var l = i.toLong
val rand = new XORShiftRandom(seed)
while (input.hasNext) {
val item = input.next()
val replacementIndex = rand.nextInt(i)
val replacementIndex = (rand.nextDouble() * l).toLong
if (replacementIndex < k) {
reservoir(replacementIndex) = item
reservoir(replacementIndex.toInt) = item
}
i += 1
l += 1
}
(reservoir, i)
(reservoir, l)
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment