Skip to content
Snippets Groups Projects
Commit 0eaf01c5 authored by Reynold Xin's avatar Reynold Xin
Browse files

Merge pull request #369 from pillis/master

SPARK-961 Add a Vector.random() method

Added method and testcases
parents 7cef8435 8d021b42
No related branches found
No related tags found
No related merge requests found
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
package org.apache.spark.util package org.apache.spark.util
import scala.util.Random
class Vector(val elements: Array[Double]) extends Serializable { class Vector(val elements: Array[Double]) extends Serializable {
def length = elements.length def length = elements.length
...@@ -124,6 +126,12 @@ object Vector { ...@@ -124,6 +126,12 @@ object Vector {
def ones(length: Int) = Vector(length, _ => 1) def ones(length: Int) = Vector(length, _ => 1)
/**
* Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers
* between 0.0 and 1.0. Optional [[scala.util.Random]] number generator can be provided.
*/
def random(length: Int, random: Random = new XORShiftRandom()) = Vector(length, _ => random.nextDouble())
class Multiplier(num: Double) { class Multiplier(num: Double) {
def * (vec: Vector) = vec * num def * (vec: Vector) = vec * num
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import scala.util.Random
import org.scalatest.FunSuite
/**
* Tests org.apache.spark.util.Vector functionality
*/
class VectorSuite extends FunSuite {
def verifyVector(vector: Vector, expectedLength: Int) = {
assert(vector.length == expectedLength)
assert(vector.elements.min > 0.0)
assert(vector.elements.max < 1.0)
}
test("random with default random number generator") {
val vector100 = Vector.random(100)
verifyVector(vector100, 100)
}
test("random with given random number generator") {
val vector100 = Vector.random(100, new Random(100))
verifyVector(vector100, 100)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment