Commit bf1311e6 authored by shivaram

Merge pull request #678 from mateiz/ml-examples

Start of ML package
parents 6ad85d09 8bbe9075
package spark.mllib.util
import spark.{RDD, SparkContext}
import spark.SparkContext._
import org.jblas.DoubleMatrix
/**
* Helper methods to load and save data
* Data format:
* <l>, <f1> <f2> ...
* where <f1>, <f2> are feature values in Double and <l> is the corresponding label as Double.
*/
object MLUtils {
/**
* @param sc SparkContext
* @param dir Directory containing the input data files.
* @return An RDD of tuples. For each tuple, the first element is the label, and the second
* element represents the feature values (an array of Double).
*/
def loadData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
sc.textFile(dir).map { line =>
val parts = line.split(",")
val label = parts(0).toDouble
val features = parts(1).trim().split(" ").map(_.toDouble)
(label, features)
}
}
def saveData(data: RDD[(Double, Array[Double])], dir: String) {
val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
dataStr.saveAsTextFile(dir)
}
/**
* Utility function to compute mean and standard deviation on a given dataset.
*
* @param data - input data set whose statistics are computed
* @param nfeatures - number of features
* @param nexamples - number of examples in input dataset
*
* @return (yMean, xColMean, xColSd) - Tuple consisting of
* yMean - mean of the labels
* xColMean - Vector with the mean of every column (or feature) of the input data
* xColSd - Vector with the standard deviation of every column (or feature) of the input data.
*/
def computeStats(data: RDD[(Double, Array[Double])], nfeatures: Int, nexamples: Long):
(Double, DoubleMatrix, DoubleMatrix) = {
val yMean: Double = data.map { case (y, features) => y }.reduce(_ + _) / nexamples
// NOTE: We shuffle X by column here to compute the column sums and sums of squares.
val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { case(y, features) =>
val nCols = features.length
// Traverse every column and emit (col, (value, value^2))
Iterator.tabulate(nCols) { i =>
(i, (features(i), features(i)*features(i)))
}
}.reduceByKey { case(x1, x2) =>
(x1._1 + x2._1, x1._2 + x2._2)
}
val xColSumsMap = xColSumSq.collectAsMap()
val xColMean = DoubleMatrix.zeros(nfeatures, 1)
val xColSd = DoubleMatrix.zeros(nfeatures, 1)
// Compute the mean and (population) variance of each column from the column sums
var col = 0
while (col < nfeatures) {
xColMean.put(col, xColSumsMap(col)._1 / nexamples)
val variance =
(xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / (nexamples)
xColSd.put(col, math.sqrt(variance))
col += 1
}
(yMean, xColMean, xColSd)
}
/**
* Return the squared Euclidean distance between two vectors.
*/
def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = {
if (v1.length != v2.length) {
throw new IllegalArgumentException("Vector sizes don't match")
}
var i = 0
var sum = 0.0
while (i < v1.length) {
sum += (v1(i) - v2(i)) * (v1(i) - v2(i))
i += 1
}
sum
}
}
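An illustrative usage sketch for MLUtils, assuming a local SparkContext and a hypothetical directory "data/sample" containing lines in the "<label>, <f1> <f2> ..." format documented above; the paths and names here are made up for illustration.
// Illustrative sketch only; "data/sample" and "data/sample-copy" are hypothetical paths.
import spark.SparkContext
import spark.mllib.util.MLUtils

object MLUtilsExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "MLUtilsExample")

    // Load "<label>, <f1> <f2> ..." lines into (label, features) pairs.
    val data = MLUtils.loadData(sc, "data/sample")

    // Per-column statistics; the feature and example counts must match the data.
    val nfeatures = data.first()._2.length
    val nexamples = data.count()
    val (yMean, xColMean, xColSd) = MLUtils.computeStats(data, nfeatures, nexamples)
    println("label mean: " + yMean)

    // Round-trip the dataset back to text files in the same format.
    MLUtils.saveData(data, "data/sample-copy")

    sc.stop()
  }
}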
# Set everything to be logged to the file ml/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=ml/target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
package spark.mllib.clustering
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
import org.jblas._
class KMeansSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test")
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
val EPSILON = 1e-4
import KMeans.{RANDOM, K_MEANS_PARALLEL}
def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")")
def prettyPrint(points: Array[Array[Double]]): String = {
points.map(prettyPrint).mkString("(", "; ", ")")
}
// Max (L-infinity) coordinate-wise distance between two points
def distance1(v1: Array[Double], v2: Array[Double]): Double = {
v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max
}
// Assert that two vectors are equal within tolerance EPSILON
def assertEqual(v1: Array[Double], v2: Array[Double]) {
def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2)
assert(v1.length == v2.length, errorMessage)
assert(distance1(v1, v2) <= EPSILON, errorMessage)
}
// Assert that two sets of points are equal, within EPSILON tolerance
def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) {
def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2)
assert(set1.length == set2.length, errorMessage)
for (v <- set1) {
val closestDistance = set2.map(w => distance1(v, w)).min
if (closestDistance > EPSILON) {
fail(errorMessage)
}
}
for (v <- set2) {
val closestDistance = set1.map(w => distance1(v, w)).min
if (closestDistance > EPSILON) {
fail(errorMessage)
}
}
}
test("single cluster") {
val data = sc.parallelize(Array(
Array(1.0, 2.0, 6.0),
Array(1.0, 3.0, 0.0),
Array(1.0, 4.0, 6.0)
))
// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
var model = KMeans.train(data, k=1, maxIterations=1)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=2)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(
data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
}
test("single cluster with big dataset") {
val smallData = Array(
Array(1.0, 2.0, 6.0),
Array(1.0, 3.0, 0.0),
Array(1.0, 4.0, 6.0)
)
val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4)
// No matter how many runs or iterations we use, we should get one cluster,
// centered at the mean of the points
var model = KMeans.train(data, k=1, maxIterations=1)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=2)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=5)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
}
test("k-means|| initialization") {
val points = Array(
Array(1.0, 2.0, 6.0),
Array(1.0, 3.0, 0.0),
Array(1.0, 4.0, 6.0),
Array(1.0, 0.0, 1.0),
Array(1.0, 1.0, 1.0)
)
val rdd = sc.parallelize(points)
// K-means|| initialization should place all clusters into distinct centers because
// it will make at least five passes, and it will give non-zero probability to each
// unselected point as long as it hasn't yet selected all of them
var model = KMeans.train(rdd, k=5, maxIterations=1)
assertSetsEqual(model.clusterCenters, points)
// Iterations of Lloyd's should not change the answer either
model = KMeans.train(rdd, k=5, maxIterations=10)
assertSetsEqual(model.clusterCenters, points)
// Neither should more runs
model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
assertSetsEqual(model.clusterCenters, points)
}
}
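For context, an illustrative sketch of calling the same trainer outside the test harness; the points, k, and iteration counts below are made up, and the entry point mirrors the KMeans.train calls exercised in the suite.
// Illustrative sketch; the input points and parameters are invented.
import spark.SparkContext
import spark.mllib.clustering.KMeans

object KMeansExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KMeansExample")
    val points = sc.parallelize(Seq(
      Array(0.0, 0.0), Array(0.1, 0.1),
      Array(9.0, 9.0), Array(9.1, 8.9)
    ))

    // Same entry point the suite uses: k clusters, a cap on Lloyd iterations,
    // several runs, and the k-means|| seeding strategy.
    val model = KMeans.train(points, k = 2, maxIterations = 10, runs = 3,
      initializationMode = KMeans.K_MEANS_PARALLEL)

    model.clusterCenters.foreach(c => println(c.mkString("(", ", ", ")")))
    sc.stop()
  }
}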
package spark.mllib.recommendation
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
import org.jblas._
class ALSSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test")
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
test("rank-1 matrices") {
testALS(10, 20, 1, 15, 0.7, 0.3)
}
test("rank-2 matrices") {
testALS(20, 30, 2, 15, 0.7, 0.3)
}
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
* @param users number of users
* @param products number of products
* @param features number of features (rank of problem)
* @param iterations number of iterations to run
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
*/
def testALS(users: Int, products: Int, features: Int, iterations: Int,
samplingRate: Double, matchThreshold: Double)
{
val rand = new Random(42)
// Create a random matrix with uniform values from -1 to 1
def randomMatrix(m: Int, n: Int) =
new DoubleMatrix(m, n, Array.fill(m * n)(rand.nextDouble() * 2 - 1): _*)
val userMatrix = randomMatrix(users, features)
val productMatrix = randomMatrix(features, products)
val trueRatings = userMatrix.mmul(productMatrix)
val sampledRatings = {
for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate)
yield (u, p, trueRatings.get(u, p))
}
val model = ALS.train(sc.parallelize(sampledRatings), features, iterations)
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
predictedU.put(u, i, vec(i))
}
val predictedP = new DoubleMatrix(products, features)
for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) {
predictedP.put(p, i, vec(i))
}
val predictedRatings = predictedU.mmul(predictedP.transpose)
for (u <- 0 until users; p <- 0 until products) {
val prediction = predictedRatings.get(u, p)
val correct = trueRatings.get(u, p)
if (math.abs(prediction - correct) > matchThreshold) {
fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
}
}
}
}
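An illustrative sketch of scoring a single (user, product) pair from the factor vectors the suite collects; the rating triples below are invented, and the prediction is the same dot-product reconstruction the test performs with jblas matrices.
// Illustrative sketch; the (user, product, rating) triples are invented.
import spark.SparkContext
import spark.mllib.recommendation.ALS

object ALSExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ALSExample")
    val ratings = sc.parallelize(Seq(
      (0, 0, 5.0), (0, 1, 1.0),
      (1, 0, 4.0), (1, 2, 2.0)
    ))

    // Factorize with a small rank and a handful of iterations, as in the suite.
    val model = ALS.train(ratings, 2, 10)

    // Predict a missing rating as the dot product of the user and product factors.
    val userFactors = model.userFeatures.collect().toMap
    val productFactors = model.productFeatures.collect().toMap
    val predicted = userFactors(1).zip(productFactors(1)).map { case (a, b) => a * b }.sum
    println("predicted rating for (user 1, product 1): " + predicted)

    sc.stop()
  }
}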
package spark.mllib.regression
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test")
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
test("logistic regression") {
val nPoints = 10000
val rnd = new Random(42)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
val A = 2.0
val B = -1.5
// NOTE: if u is uniform on [0, 1] then ln(u) - ln(1-u) is Logistic(0, 1)
val unifRand = new scala.util.Random(45)
val rLogis = (0 until nPoints).map { i =>
val u = unifRand.nextDouble()
math.log(u) - math.log(1.0-u)
}
// R equivalent: y <- A + B*x + rlogis(nPoints)
//               y <- as.numeric(y > 0)
val y = (0 until nPoints).map { i =>
val yVal = A + B * x1(i) + rLogis(i)
if (yVal > 0) 1.0 else 0.0
}
val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i)))).toArray
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val lr = new LogisticRegression().setStepSize(10.0)
.setNumIterations(20)
val model = lr.train(testRDD)
val weight0 = model.weights.get(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
}
}
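The noise trick used in the suite above is inverse-transform sampling: the Logistic(0, 1) CDF is F(x) = 1 / (1 + e^(-x)), so F^(-1)(u) = ln(u / (1 - u)) = ln(u) - ln(1 - u), and mapping uniform draws through F^(-1) yields logistic noise. A standalone sketch of the same generator (the seed and sample size are arbitrary):
// Illustrative sketch of the inverse-CDF sampler used in the suite.
import scala.util.Random

object LogisticNoiseExample {
  def main(args: Array[String]) {
    val rng = new Random(7)

    // Draw one Logistic(0, 1) sample by mapping a uniform draw through the inverse CDF.
    def logisticSample(): Double = {
      val u = rng.nextDouble()
      math.log(u) - math.log(1.0 - u)
    }

    // Empirical check: the mean of many draws should be close to 0.
    val draws = Array.fill(100000)(logisticSample())
    println("empirical mean: " + draws.sum / draws.length)
  }
}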
package spark.mllib.regression
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
import spark.SparkContext._
class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
val sc = new SparkContext("local", "test")
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
// Test if we can correctly learn Y = 3 + X1 + X2 when
// X1 and X2 are collinear.
test("multi-collinear variables") {
val rnd = new Random(43)
val x1 = Array.fill[Double](20)(rnd.nextGaussian())
// Pick a mean close to mean of x1
val rnd1 = new Random(42) //new NormalDistribution(0.1, 0.01)
val x2 = Array.fill[Double](20)(0.1 + rnd1.nextGaussian() * 0.01)
val xMat = (0 until 20).map(i => Array(x1(i), x2(i))).toArray
val y = xMat.map(i => 3 + i(0) + i(1))
val testData = (0 until 20).map(i => (y(i), xMat(i))).toArray
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val ridgeReg = new RidgeRegression().setLowLambda(0)
.setHighLambda(10)
val model = ridgeReg.train(testRDD)
assert(model.intercept >= 2.9 && model.intercept <= 3.1)
assert(model.weights.length === 2)
assert(model.weights.get(0) >= 0.9 && model.weights.get(0) <= 1.1)
assert(model.weights.get(1) >= 0.9 && model.weights.get(1) <= 1.1)
}
}
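An illustrative end-to-end sketch that feeds MLUtils.loadData output into the same ridge trainer exercised above; the input directory "data/ridge" is hypothetical, and the lambda search range [0, 10] simply mirrors the test.
// Illustrative sketch; "data/ridge" is a hypothetical directory of "<label>, <f1> <f2> ..." lines.
import spark.SparkContext
import spark.mllib.regression.RidgeRegression
import spark.mllib.util.MLUtils

object RidgeRegressionExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "RidgeRegressionExample")
    val trainingData = MLUtils.loadData(sc, "data/ridge").cache()

    // Search for the regularization parameter over the same [0, 10] range as the test.
    val ridge = new RidgeRegression().setLowLambda(0).setHighLambda(10)
    val model = ridge.train(trainingData)

    println("intercept: " + model.intercept)
    println("weights:   " +
      (0 until model.weights.length).map(i => model.weights.get(i)).mkString(", "))

    sc.stop()
  }
}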
@@ -25,7 +25,7 @@ object SparkBuild extends Build {
 //val HADOOP_MAJOR_VERSION = "2"
 //val HADOOP_YARN = true
-lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming)
+lazy val root = Project("root", file("."), settings = rootSettings) aggregate(core, repl, examples, bagel, streaming, mllib)
 lazy val core = Project("core", file("core"), settings = coreSettings)
@@ -37,6 +37,8 @@ object SparkBuild extends Build {
 lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn (core)
+lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn (core)
 // A configuration to set an alternative publishLocalConfiguration
 lazy val MavenCompile = config("m2r") extend(Compile)
 lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -219,6 +221,13 @@ object SparkBuild extends Build {
 def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
+def mllibSettings = sharedSettings ++ Seq(
+  name := "spark-mllib",
+  libraryDependencies ++= Seq(
+    "org.jblas" % "jblas" % "1.2.3"
+  )
+)
 def streamingSettings = sharedSettings ++ Seq(
   name := "spark-streaming",
   resolvers ++= Seq(
...