diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index a2d85a68cd32786434f7deec59491a184a18fc88..9eab7efc160dae6ee37eea245d43be51c809d01e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import org.apache.commons.math3.distribution.{ExponentialDistribution, - GammaDistribution, LogNormalDistribution, PoissonDistribution} +import org.apache.commons.math3.distribution._ import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -195,3 +194,27 @@ class LogNormalGenerator @Since("1.3.0") ( @Since("1.3.0") override def copy(): LogNormalGenerator = new LogNormalGenerator(mean, std) } + +/** + * :: DeveloperApi :: + * Generates i.i.d. samples from the Weibull distribution with the + * given shape and scale parameter. + * + * @param alpha shape parameter for the Weibull distribution. + * @param beta scale parameter for the Weibull distribution. + */ +@DeveloperApi +class WeibullGenerator( + val alpha: Double, + val beta: Double) extends RandomDataGenerator[Double] { + + private val rng = new WeibullDistribution(alpha, beta) + + override def nextValue(): Double = rng.sample() + + override def setSeed(seed: Long): Unit = { + rng.reseedRandomGenerator(seed) + } + + override def copy(): WeibullGenerator = new WeibullGenerator(alpha, beta) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala index a5ca1518f82f5c071833eca2148bad87372dff95..8416771552fd3267ba5f086b82d343dd37833a6d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomDataGeneratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.random -import scala.math +import org.apache.commons.math3.special.Gamma import org.apache.spark.SparkFunSuite import org.apache.spark.util.StatCounter @@ -136,4 +136,18 @@ class RandomDataGeneratorSuite extends SparkFunSuite { distributionChecks(gamma, expectedMean, expectedStd, 0.1) } } + + test("WeibullGenerator") { + List((1.0, 2.0), (2.0, 3.0), (2.5, 3.5), (10.4, 2.222)).map { + case (alpha: Double, beta: Double) => + val weibull = new WeibullGenerator(alpha, beta) + apiChecks(weibull) + + val expectedMean = math.exp(Gamma.logGamma(1 + (1 / alpha))) * beta + val expectedVariance = math.exp( + Gamma.logGamma(1 + (2 / alpha))) * beta * beta - expectedMean * expectedMean + val expectedStd = math.sqrt(expectedVariance) + distributionChecks(weibull, expectedMean, expectedStd, 0.1) + } + } }