Skip to content
Snippets Groups Projects
Commit 65de73c7 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #185 from mkolod/random-number-generator

XORShift RNG with unit tests and benchmark

This patch was introduced to address SPARK-950 - the discussion below the ticket explains not only the rationale, but also the design and testing decisions: https://spark-project.atlassian.net/browse/SPARK-950

To run the unit tests, start the SBT console and type:
compile
test-only org.apache.spark.util.XORShiftRandomSuite
To run the benchmark, type:
project core
console
Once the Scala console starts, type:
org.apache.spark.util.XORShiftRandom.benchmark(100000000)
XORShiftRandom is also an object with a main method taking the
number of iterations as an argument, so you can also run it
from the command line.
parents 972171b9 22724659
No related branches found
No related tags found
No related merge requests found
......@@ -823,4 +823,28 @@ private[spark] object Utils extends Logging {
return System.getProperties().clone()
.asInstanceOf[java.util.Properties].toMap[String, String]
}
/**
 * Evaluates a side-effecting task a fixed number of times.
 * Written as a plain `while` loop instead of a for comprehension so that the
 * JVM JIT is free to optimize the loop body.
 * @param numIters how many times `f` is evaluated; nothing runs if it is zero or negative
 * @param f by-name task that is re-evaluated on every iteration
 */
def times(numIters: Int)(f: => Unit): Unit = {
  var remaining = numIters
  while (remaining > 0) {
    f
    remaining -= 1
  }
}
/**
 * Measures how long it takes to evaluate a task a given number of times.
 * The iterations are driven by `times`, i.e. a plain `while` loop that permits
 * JVM JIT optimization.
 * @param numIters number of iterations
 * @param f function to be executed
 * @return elapsed time in milliseconds
 */
def timeIt(numIters: Int)(f: => Unit): Long = {
  // System.nanoTime is a monotonic clock intended for measuring elapsed time;
  // System.currentTimeMillis follows the wall clock and may jump when it is adjusted.
  val startNanos = System.nanoTime()
  times(numIters)(f)
  // Convert to milliseconds so the returned unit is the same as before.
  (System.nanoTime() - startNanos) / 1000000L
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.{Random => JavaRandom}
import org.apache.spark.util.Utils.timeIt
/**
 * This class implements a XORShift random number generator algorithm
 * Source:
 * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
 * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
 * This implementation is approximately 3.5 times faster than
 * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
 * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
 * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
 * for each thread.
 *
 * A seed of zero is a fixed point of the xorshift recurrence (every output would be 0),
 * so a zero seed is replaced by a fixed non-zero constant. All other seeds are used as is.
 *
 * @param init initial seed of the generator
 */
private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {

  /** Creates a generator seeded from `System.nanoTime`. */
  def this() = this(System.nanoTime)

  // Internal generator state; kept non-zero, see nonZeroSeed.
  private var seed = nonZeroSeed(init)

  /**
   * Maps the degenerate all-zero state to a fixed non-zero constant.
   * Defined as a method (not a val) so it is usable even if `setSeed` is invoked
   * while the superclass constructor is still running.
   */
  private def nonZeroSeed(s: Long): Long = if (s == 0L) 0x2545F4914F6CDD1DL else s

  /**
   * Re-seeds this generator. Without this override only the (unused) state of
   * java.util.Random would be reset and the xorshift sequence would continue unchanged.
   */
  override def setSeed(s: Long): Unit = {
    // Let the superclass reset its own bookkeeping (e.g. the cached nextGaussian value).
    super.setSeed(s)
    seed = nonZeroSeed(s)
  }

  // we need to just override next - this will be called by nextInt, nextDouble,
  // nextGaussian, nextLong, etc.
  override protected def next(bits: Int): Int = {
    var nextSeed = seed ^ (seed << 21)
    nextSeed ^= (nextSeed >>> 35)
    nextSeed ^= (nextSeed << 4)
    seed = nextSeed
    // keep only the requested number of low-order bits
    (nextSeed & ((1L << bits) - 1)).toInt
  }
}
/** Contains benchmark method and main method to run benchmark of the RNG */
private[spark] object XORShiftRandom {

  /**
   * Main method for running benchmark
   * Prints the usage text and exits with status 1 unless exactly one integer
   * argument is supplied.
   * @param args takes one argument - the number of random numbers to generate
   */
  def main(args: Array[String]): Unit = {
    // Parse defensively: a missing, surplus or non-numeric argument yields None.
    val requestedIters: Option[Int] =
      if (args.length == 1) {
        try {
          Some(args(0).toInt)
        } catch {
          case _: NumberFormatException => None
        }
      } else {
        None
      }

    requestedIters match {
      case Some(numIters) =>
        println(benchmark(numIters))
      case None =>
        println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
        println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
        System.exit(1)
    }
  }

  /**
   * Times the generation of `numIters` random integers with both generators,
   * after an untimed warm-up phase.
   * @param numIters Number of random numbers to generate while running the benchmark
   * @return Map of execution times for {@link java.util.Random java.util.Random}
   * (key "javaTime") and XORShift (key "xorTime"), as returned by `timeIt`
   */
  def benchmark(numIters: Int): Map[String, Long] = {
    val seed = 1L
    val warmUpIters = 1e6.toInt
    val javaRand = new JavaRandom(seed)
    val xorRand = new XORShiftRandom(seed)

    // this is just to warm up the JIT - we're not timing anything
    timeIt(warmUpIters) {
      javaRand.nextInt()
      xorRand.nextInt()
    }

    /* Return results as a map instead of just printing to screen
       in case the user wants to do something with them */
    val javaTime = timeIt(numIters) { javaRand.nextInt() }
    val xorTime = timeIt(numIters) { xorRand.nextInt() }
    Map("javaTime" -> javaTime, "xorTime" -> xorTime)
  }
}
\ No newline at end of file
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.Random
import org.scalatest.FlatSpec
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.Utils.times
/** Statistical sanity check for [[XORShiftRandom]]. */
class XORShiftRandomSuite extends FunSuite with ShouldMatchers {

  /**
   * Values shared by the tests. A named class is used instead of an anonymous
   * `new { ... }` so that member access is a regular (non-reflective) call.
   */
  class Fixture {
    val seed = 1L
    val xorRand = new XORShiftRandom(seed)
    val hundMil = 1e8.toInt
  }

  /** Creates a fresh fixture, i.e. a freshly seeded generator, on every call. */
  def fixture: Fixture = new Fixture

  /*
   * This test is based on a chi-squared test for randomness. The values are hard-coded
   * so as not to create Spark's dependency on apache.commons.math3 just to call one
   * method for calculating the exact p-value for a given number of random numbers
   * and bins. In case one would want to move to a full-fledged test based on
   * apache.commons.math3, the relevant class is here:
   * org.apache.commons.math3.stat.inference.ChiSquareTest
   */
  test ("XORShift generates valid random numbers") {
    val f = fixture
    val rng = f.xorRand
    val numBins = 10
    // create 10 bins
    val bins = Array.fill(numBins)(0)

    // populate bins based on modulus of the random number.
    // The remainder is taken before the absolute value: math.abs(Int.MinValue) is still
    // negative and would yield a negative index, whereas math.abs(x % numBins) is always
    // within [0, numBins). For every other value both forms pick the same bin.
    times(f.hundMil) { bins(math.abs(rng.nextInt() % numBins)) += 1 }

    /* since the seed is deterministic, until the algorithm is changed, we know the result will be
     * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
     * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
     * significance level. However, should the RNG implementation change, the test should still
     * pass at the same significance level. The chi-squared test done in R gave the following
     * results:
     * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
     * 10000790, 10002286, 9998699))
     * Chi-squared test for given probabilities
     * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
     * 10002286, 9998699)
     * X-squared = 11.975, df = 9, p-value = 0.2147
     * Note that the p-value was ~0.22. The test fails when the p-value drops below 0.05,
     * which for 10 bins (9 degrees of freedom) happens at X-squared of ~16.9196. So, the
     * test will fail if X-squared is greater than or equal to that number.
     */
    val binSize = f.hundMil / numBins
    val xSquared = bins.map(x => math.pow((binSize - x), 2) / binSize).sum
    xSquared should be < (16.9196)
  }
}
\ No newline at end of file
......@@ -18,15 +18,16 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.util.XORShiftRandom
import org.jblas.DoubleMatrix
/**
......@@ -195,7 +196,7 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
......@@ -210,7 +211,7 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
val seed = new Random().nextInt()
val seed = new XORShiftRandom().nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
......@@ -222,7 +223,7 @@ class KMeans private (
for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
val rand = new Random(seed ^ (step << 16) ^ index)
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
for {
p <- points
r <- 0 until runs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment