Skip to content
Snippets Groups Projects
Commit d4c8e69d authored by Matei Zaharia
Browse files

K-means example

parent 157279e9
No related branches found
No related tags found
No related merge requests found
package spark.examples
import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._
/**
 * K-means clustering example.
 *
 * Reads points from a text file (one point per line, coordinates separated by
 * single spaces), picks `k` pseudo-random initial centers and then runs a fixed
 * number of assignment/averaging iterations, printing the centers before and
 * after.
 */
object SparkKMeans {
  /**
   * Parses one input line into a vector.
   *
   * @param line coordinates separated by single space characters
   * @return a vector holding the parsed coordinates, in order
   */
  def parseVector(line: String): Vector = {
    new Vector(line.split(' ').map(_.toDouble))
  }

  /**
   * Finds the index of the center nearest to `p`, by squared distance.
   *
   * Ties are resolved in favour of the lowest index because only a strictly
   * smaller distance replaces the current best.
   *
   * @param p       the point to classify
   * @param centers candidate centers; element 0 is read unconditionally, so the
   *                array must not be empty
   * @return index into `centers` of the nearest center
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Program entry point.
   *
   * @param args `<master> <file> <dimensions> <k> <iters>`
   */
  def main(args: Array[String]): Unit = {
    // All five arguments are read below (args(0) .. args(4)), so anything
    // shorter must be rejected here rather than failing later on an index error.
    if (args.length < 5) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines = sc.textFile(args(1))
    // The points are reused on every iteration, hence the cache() call.
    val points = lines.map(parseVector _).cache()
    val dimensions = args(2).toInt
    val k = args(3).toInt
    val iterations = args(4).toInt

    // Initialize cluster centers randomly; the fixed seed makes runs repeatable.
    // Each coordinate is 2 * nextDouble - 1.
    val rand = new Random(42)
    val centers = new Array[Vector](k)
    for (i <- 0 until k)
      centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
    println("Initial centers: " + centers.mkString(", "))

    for (i <- 1 to iterations) {
      println("On iteration " + i)

      // Map each point to the index of its closest center and a (point, 1) pair
      // that we will use to compute an average later
      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }

      // Compute the new centers by summing the (point, 1) pairs and taking an average
      val newCenters = mappedPoints.reduceByKey {
        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
      }.map {
        case (id, (sum, count)) => (id, sum / count)
      }.collect

      // Update the centers array with the new centers we collected. Only ids
      // present in newCenters are overwritten, so a center that attracted no
      // points keeps its previous value.
      for ((id, value) <- newCenters) {
        centers(id) = value
      }
    }

    println("Final centers: " + centers.mkString(", "))
  }
}
......@@ -21,19 +21,35 @@ class Vector(val elements: Array[Double]) extends Serializable {
if (length != other.length)
throw new IllegalArgumentException("Vectors of different length")
var ans = 0.0
for (i <- 0 until length)
var i = 0
while (i < length) {
ans += this(i) * other(i)
i += 1
}
return ans
}
def * ( scale: Double): Vector = Vector(length, i => this(i) * scale)
def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
def / (d: Double): Vector = this * (1 / d)
def unary_- = this * -1
def sum = elements.reduceLeft(_ + _)
override def toString = elements.mkString("(", ", ", ")")
/**
 * Sum of the squared component-wise differences between this vector and
 * `other` (the squared Euclidean distance).
 *
 * @param other the vector to measure against; must have the same length
 * @return the accumulated sum of squared differences
 * @throws IllegalArgumentException if the two vectors differ in length
 */
def squaredDist(other: Vector): Double = {
  // Same guard the other binary operations in this class use: without it a
  // longer `other` would have its extra components silently ignored.
  if (length != other.length)
    throw new IllegalArgumentException("Vectors of different length")
  var ans = 0.0
  var i = 0
  // Plain while loop over the indices, accumulating in index order.
  while (i < length) {
    val diff = this(i) - other(i)
    ans += diff * diff
    i += 1
  }
  ans
}
/**
 * Distance to `other`, computed as the square root of `squaredDist(other)`.
 *
 * @param other the vector to measure against
 * @return the square root of the squared distance
 */
def dist(other: Vector): Double = {
  val squared = squaredDist(other)
  math.sqrt(squared)
}
override def toString = elements.mkString("(", ", ", ")")
}
object Vector {
......
0.1 0.2 0.0 0.2
0.2 0.2 0.3 0.2
0.3 0.0 0.0 0.1
0.1 0.2 0.3 0.2
1.1 0.2 0.0 0.2
1.2 0.2 0.3 0.2
1.3 0.0 0.0 0.1
1.1 0.2 0.3 0.2
0.1 1.2 1.0 0.2
0.2 1.2 1.3 0.2
0.3 1.0 1.0 0.1
0.1 1.2 1.3 0.2
0.1 0.2 0.0 1.2
0.2 0.2 0.3 1.2
0.3 0.0 0.0 1.1
0.1 0.2 0.3 1.2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment