Commit 09bdfe3b authored by Marek Kolodziej

XORShift RNG with unit tests and benchmark

To run the unit test, start the SBT console and type:
  compile
  test-only org.apache.spark.util.XORShiftRandomSuite
To run the benchmark, type:
  project core
  console
Once the Scala console starts, type:
  org.apache.spark.util.XORShiftRandom.benchmark(100000000)
parent e2ebc3a9
@@ -17,7 +17,7 @@
 package org.apache.spark.rdd
-import java.util.Random
+import org.apache.spark.util.{XORShiftRandom => Random}
 import scala.collection.Map
 import scala.collection.JavaConversions.mapAsScalaMap
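The rename in this import is what makes the change minimally invasive: XORShiftRandom extends java.util.Random, so aliasing it as Random leaves every existing call site compiling against the same API. A minimal sketch of the pattern (the call site below is hypothetical, not from the commit):

  import org.apache.spark.util.{XORShiftRandom => Random}

  // Code written against java.util.Random keeps working unchanged,
  // but now draws from the faster XORShift generator.
  val rand = new Random(47)
  val keep = rand.nextDouble() < 0.5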
@@ -818,9 +818,42 @@ private[spark] object Utils extends Logging {
     hashAbs
   }

-  /** Returns a copy of the system properties that is thread-safe to iterate over. */
+  /* Returns a copy of the system properties that is thread-safe to iterate over. */
   def getSystemProperties(): Map[String, String] = {
     return System.getProperties().clone()
       .asInstanceOf[java.util.Properties].toMap[String, String]
   }

+  /* Used for performance testing along with the intToTimesInt() and timeIt methods.
+   * It uses a while loop instead of a for comprehension since the JIT will
+   * optimize the while loop better than the "for" closure, e.g.:
+   *
+   *   import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt}
+   *   import java.util.Random
+   *   val rand = new Random()
+   *   timeIt(rand.nextDouble, 10000000)
+   */
+  class TimesInt(i: Int) {
+    def times(f: => Unit) = {
+      var x = 1
+      while (x <= i) {
+        f
+        x += 1
+      }
+    }
+  }
+
+  /* Used in conjunction with TimesInt since this is Scala 2.9.3
+   * instead of 2.10 and we don't have implicit classes. */
+  implicit def intToTimesInt(i: Int) = new TimesInt(i)
+
+  /* See TimesInt for a usage example. */
+  def timeIt(f: => Unit, iters: Int): Long = {
+    val start = System.currentTimeMillis
+    iters.times(f)
+    System.currentTimeMillis - start
+  }
 }
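For reference, a small sketch of how these helpers compose once the implicit conversion is in scope (names from the diff above; the values are illustrative):

  import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt}

  // intToTimesInt lets a plain Int receive the times() method.
  var count = 0
  5.times { count += 1 }    // count is now 5

  // timeIt runs the by-name block iters times and returns elapsed milliseconds.
  val rand = new java.util.Random()
  val elapsedMs = timeIt(rand.nextDouble, 10000000)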
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.{Random => JavaRandom}
import Utils.{TimesInt, intToTimesInt, timeIt}
class XORShiftRandom(init: Long) extends JavaRandom(init) {

  def this() = this(System.nanoTime)

  var seed = init

  // We need to just override next - this will be called by nextInt, nextDouble,
  // nextGaussian, nextLong, etc.
  override protected def next(bits: Int): Int = {
    var nextSeed = seed ^ (seed << 21)
    nextSeed ^= (nextSeed >>> 35)
    nextSeed ^= (nextSeed << 4)
    seed = nextSeed
    (nextSeed & ((1L << bits) - 1)).asInstanceOf[Int]
  }
}
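Because XORShiftRandom only overrides the protected next(bits) method, every public generator inherited from java.util.Random (nextInt, nextLong, nextDouble, nextGaussian, ...) automatically draws from the XORShift stream. A brief illustrative sketch:

  val rng = new XORShiftRandom(42L)   // a fixed seed gives a deterministic stream
  val i = rng.nextInt()               // one call to next(32)
  val d = rng.nextDouble()            // java.util.Random builds this from two next() calls
  val g = rng.nextGaussian()          // likewise routed through next()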
object XORShiftRandom {

  def benchmark(numIters: Int) = {

    val seed = 1L
    val million = 1e6.toInt
    val javaRand = new JavaRandom(seed)
    val xorRand = new XORShiftRandom(seed)

    // Warm up the JIT.
    million.times {
      javaRand.nextInt
      xorRand.nextInt
    }

    /* Return results as a map instead of just printing to screen
     * in case the user wants to do something with them. */
    Map("javaTime" -> timeIt(javaRand.nextInt, numIters),
        "xorTime" -> timeIt(xorRand.nextInt, numIters))
  }
}
\ No newline at end of file
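Following the commit message, the benchmark is run from the Scala console; a sketch of consuming the returned map (key names from the code above; actual timings are machine-dependent):

  val results = org.apache.spark.util.XORShiftRandom.benchmark(100000000)
  val javaMs = results("javaTime")    // java.util.Random time, in ms
  val xorMs = results("xorTime")      // XORShiftRandom time, in ms
  println("speedup: " + (javaMs.toDouble / xorMs) + "x")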
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.Random
import org.scalatest.FlatSpec
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt}
class XORShiftRandomSuite extends FunSuite with ShouldMatchers {

  def fixture = new {
    val seed = 1L
    val xorRand = new XORShiftRandom(seed)
    val hundMil = 1e8.toInt
  }

  /*
   * This test is based on a chi-squared test for randomness. The values are hard-coded
   * so as not to add a Spark dependency on apache.commons.math3 just to call one
   * method for calculating the exact p-value for a given number of random numbers
   * and bins. Should one want to move to a full-fledged test based on
   * apache.commons.math3, the relevant class is:
   * org.apache.commons.math3.stat.inference.ChiSquareTest
   */
  test("XORShift generates valid random numbers") {

    val f = fixture

    val numBins = 10
    // Create 10 bins.
    val bins = Array.fill(numBins)(0)

    // Populate the bins based on the modulus of each random number.
    f.hundMil.times(bins(math.abs(f.xorRand.nextInt) % 10) += 1)

    /* Since the seed is deterministic, until the algorithm is changed, we know the result will
     * be exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
     * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
     * significance level. However, should the RNG implementation change, the test should still
     * pass at the same significance level. The chi-squared test done in R gave the following
     * results:
     *   > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
     *     10000790, 10002286, 9998699))
     *   Chi-squared test for given probabilities
     *   data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
     *     10002286, 9998699)
     *   X-squared = 11.975, df = 9, p-value = 0.2147
     * Note that the p-value was ~0.21. The test fails if the p-value drops below alpha = 0.05,
     * which for 100 million random numbers and 10 bins (df = 9) happens at an X-squared of
     * ~16.9196. So, the test fails if X-squared is greater than or equal to that number.
     */
    val binSize = f.hundMil / numBins
    val xSquared = bins.map(x => math.pow(binSize - x, 2) / binSize).sum
    xSquared should be < (16.9196)
  }
}
\ No newline at end of file
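As the comment in the suite notes, the exact p-value could be computed with apache.commons.math3 instead of hard-coding the critical value. A minimal sketch of that alternative, assuming the extra dependency were acceptable (it reuses bins, numBins, and binSize from the test above):

  import org.apache.commons.math3.stat.inference.ChiSquareTest

  // Under the uniformity hypothesis, each bin expects binSize draws.
  val expected = Array.fill(numBins)(binSize.toDouble)
  val observed = bins.map(_.toLong)

  // chiSquareTest returns the p-value; the check passes if it exceeds alpha = 0.05.
  val pValue = new ChiSquareTest().chiSquareTest(expected, observed)
  pValue should be > (0.05)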
@@ -18,7 +18,7 @@
 package org.apache.spark.mllib.clustering
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
+import org.apache.spark.util.{XORShiftRandom => Random}
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._