Commit e1c814be authored by Edison Tung

Renamed SparkLocalKMeans to SparkKMeans

parent a3bc012a
This commit removes the old SparkKMeans example and renames SparkLocalKMeans to SparkKMeans in its place. The diff below shows the two files; minor inconsistencies in the original (a leftover "SparkLocalKMeans" in the usage string and app name, a dead local variable) are fixed here.

SparkKMeans.scala (the renamed file, formerly SparkLocalKMeans.scala):

package spark.examples

import java.util.Random
import Vector._
import spark.SparkContext
import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

object SparkKMeans {
  val R = 1000 // Scaling factor
  val rand = new Random(42)

  def parseVector(line: String): Vector = {
    return new Vector(line.split(' ').map(_.toDouble))
  }

  // Return the 1-based key of the center in `centers` nearest to p
  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
    var bestIndex = 0
    var closest = Double.PositiveInfinity
    for (i <- 1 to centers.size) {
      val vCurr = centers.get(i).get
      val tempDist = p.squaredDist(vCurr)
      if (tempDist < closest) {
        closest = tempDist
        bestIndex = i
      }
    }
    return bestIndex
  }

  def main(args: Array[String]) {
    if (args.length < 4) {
      System.err.println("Usage: SparkKMeans <master> <file> <k> <convergeDist>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines = sc.textFile(args(1))
    val data = lines.map(parseVector _).cache()
    val K = args(2).toInt
    val convergeDist = args(3).toDouble

    // Seed the centers with K points sampled from the input data
    var points = data.takeSample(false, K, 42)
    var kPoints = new HashMap[Int, Vector]
    var tempDist = 1.0
    for (i <- 1 to points.size) {
      kPoints.put(i, points(i-1))
    }

    // Assign points, recompute the centers, and repeat until the centers
    // move less than convergeDist in total
    while (tempDist > convergeDist) {
      // Map each point to (key of its closest center, (point, 1))
      var closest = data.map(p => (closestPoint(p, kPoints), (p, 1)))
      // Sum the points and counts per center, then divide to get the new means
      var pointStats = closest.reduceByKey { case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2) }
      var newPoints = pointStats.map { mapping => (mapping._1, mapping._2._1 / mapping._2._2) }.collect()
      // Total squared distance moved by the centers this iteration
      tempDist = 0.0
      for (mapping <- newPoints) {
        tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
      }
      for (newP <- newPoints) {
        kPoints.put(newP._1, newP._2)
      }
    }

    println("Final centers: " + kPoints)
  }
}
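As a reference for the logic the job above distributes, here is a minimal, self-contained sketch of the same convergence-based k-means loop on plain Array[Double] vectors, runnable without Spark. Everything in it (LocalKMeansSketch, Vec, localKMeans, the sample data) is illustrative and not part of this commit, and it seeds the centers with the first k points rather than a random sample:

object LocalKMeansSketch {
  type Vec = Array[Double]

  def squaredDist(a: Vec, b: Vec): Double =
    a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

  def add(a: Vec, b: Vec): Vec =
    a.zip(b).map { case (x, y) => x + y }

  def closestPoint(p: Vec, centers: Array[Vec]): Int =
    centers.indices.minBy(i => squaredDist(p, centers(i)))

  def localKMeans(data: Array[Vec], k: Int, convergeDist: Double): Array[Vec] = {
    var centers = data.take(k) // illustrative seeding; the Spark job samples instead
    var moved = Double.PositiveInfinity
    while (moved > convergeDist) {
      // Assign every point to its nearest center
      val clusters = data.groupBy(p => closestPoint(p, centers))
      // Replace each center with the mean of its assigned points
      val newCenters = centers.indices.map { i =>
        clusters.get(i) match {
          case Some(ps) => ps.reduce(add).map(_ / ps.length)
          case None     => centers(i) // keep a center that attracted no points
        }
      }.toArray
      // Stop once the centers barely move
      moved = centers.zip(newCenters).map { case (a, b) => squaredDist(a, b) }.sum
      centers = newCenters
    }
    centers
  }

  def main(args: Array[String]): Unit = {
    val data = Array(Array(0.0, 0.0), Array(0.1, 0.1), Array(9.0, 9.0), Array(9.1, 8.9))
    localKMeans(data, 2, 0.001).foreach(c => println(c.mkString(" ")))
  }
}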
SparkKMeans.scala (the previous version, deleted by this commit; its argument check is corrected here to < 5, since it reads args(4)):

package spark.examples

import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

object SparkKMeans {
  def parseVector(line: String): Vector = {
    return new Vector(line.split(' ').map(_.toDouble))
  }

  // Return the index of the center in `centers` nearest to p
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    return bestIndex
  }

  def main(args: Array[String]) {
    if (args.length < 5) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines = sc.textFile(args(1))
    val points = lines.map(parseVector _).cache()
    val dimensions = args(2).toInt
    val k = args(3).toInt
    val iterations = args(4).toInt

    // Initialize cluster centers randomly
    val rand = new Random(42)
    var centers = new Array[Vector](k)
    for (i <- 0 until k)
      centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
    println("Initial centers: " + centers.mkString(", "))

    for (i <- 1 to iterations) {
      println("On iteration " + i)
      // Map each point to the index of its closest center and a (point, 1) pair
      // that we will use to compute an average later
      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }
      // Compute the new centers by summing the (point, 1) pairs and taking an average
      val newCenters = mappedPoints.reduceByKey {
        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
      }.map {
        case (id, (sum, count)) => (id, sum / count)
      }.collect()
      // Update the centers array with the new centers we collected
      for ((id, value) <- newCenters) {
        centers(id) = value
      }
    }

    println("Final centers: " + centers.mkString(", "))
  }
}
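The two versions differ mainly in how they start and stop: the deleted file seeded k random centers in [-1, 1] per dimension and ran a fixed number of iterations, while the surviving file seeds its centers from a sample of the input and loops until the total squared movement of the centers falls below convergeDist. A possible invocation of the renamed example, assuming the run launcher script that early Spark checkouts shipped at the repository root and an illustrative whitespace-separated input file (both names are assumptions, not part of this commit):

./run spark.examples.SparkKMeans local[2] kmeans_data.txt 2 0.001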