Commit e1c814be authored by Edison Tung

Renamed SparkLocalKMeans to SparkKMeans

parent a3bc012a
SparkKMeans.scala before this commit (removed):

package spark.examples

import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

object SparkKMeans {
  def parseVector(line: String): Vector = {
    return new Vector(line.split(' ').map(_.toDouble))
  }

  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    return bestIndex
  }

  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkKMeans")
    val lines = sc.textFile(args(1))
    val points = lines.map(parseVector _).cache()
    val dimensions = args(2).toInt
    val k = args(3).toInt
    val iterations = args(4).toInt

    // Initialize cluster centers randomly
    val rand = new Random(42)
    var centers = new Array[Vector](k)
    for (i <- 0 until k)
      centers(i) = Vector(dimensions, _ => 2 * rand.nextDouble - 1)
    println("Initial centers: " + centers.mkString(", "))

    for (i <- 1 to iterations) {
      println("On iteration " + i)

      // Map each point to the index of its closest center and a (point, 1) pair
      // that we will use to compute an average later
      val mappedPoints = points.map { p => (closestCenter(p, centers), (p, 1)) }

      // Compute the new centers by summing the (point, 1) pairs and taking an average
      val newCenters = mappedPoints.reduceByKey {
        case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2)
      }.map {
        case (id, (sum, count)) => (id, sum / count)
      }.collect

      // Update the centers array with the new centers we collected
      for ((id, value) <- newCenters) {
        centers(id) = value
      }
    }

    println("Final centers: " + centers.mkString(", "))
  }
}

SparkKMeans.scala after this commit (the former SparkLocalKMeans, renamed):

package spark.examples

import java.util.Random
import Vector._
import spark.SparkContext
import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet

object SparkKMeans {
  val R = 1000 // Scaling factor
  val rand = new Random(42)

  def parseVector(line: String): Vector = {
    return new Vector(line.split(' ').map(_.toDouble))
  }

  def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = {
    var index = 0
    var bestIndex = 0
    var closest = Double.PositiveInfinity

    for (i <- 1 to centers.size) {
      val vCurr = centers.get(i).get
      val tempDist = p.squaredDist(vCurr)
      if (tempDist < closest) {
        closest = tempDist
        bestIndex = i
      }
    }

    return bestIndex
  }

  def main(args: Array[String]) {
    if (args.length < 4) {
      System.err.println("Usage: SparkLocalKMeans <master> <file> <k> <convergeDist>")
      System.exit(1)
    }
    val sc = new SparkContext(args(0), "SparkLocalKMeans")
    val lines = sc.textFile(args(1))
    val data = lines.map(parseVector _).cache()
    val K = args(2).toInt
    val convergeDist = args(3).toDouble

    var points = data.takeSample(false, K, 42)
    var kPoints = new HashMap[Int, Vector]
    var tempDist = 1.0

    for (i <- 1 to points.size) {
      kPoints.put(i, points(i - 1))
    }

    while (tempDist > convergeDist) {
      var closest = data.map(p => (closestPoint(p, kPoints), (p, 1)))

      var pointStats = closest.reduceByKey { case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2) }

      var newPoints = pointStats.map { mapping => (mapping._1, mapping._2._1 / mapping._2._2) }.collect()

      tempDist = 0.0
      for (mapping <- newPoints) {
        tempDist += kPoints.get(mapping._1).get.squaredDist(mapping._2)
      }

      for (newP <- newPoints) {
        kPoints.put(newP._1, newP._2)
      }
    }

    println("Final centers: " + kPoints)
  }
}
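For anyone trying the renamed example, a minimal smoke-test driver is sketched below. It assumes the new argument order from the usage string, <master> <file> <k> <convergeDist>, and that the input file holds one point per line as space-separated doubles (matching parseVector). The object name SparkKMeansSmokeTest, the file name kmeans_data.txt, and the values k = 2 and convergeDist = 0.001 are hypothetical, and "local" runs Spark in-process.

// Hypothetical driver, assuming it sits alongside the example in spark.examples.
package spark.examples

object SparkKMeansSmokeTest {
  def main(args: Array[String]) {
    // Arguments: <master> <file> <k> <convergeDist>.
    // kmeans_data.txt is an invented file containing lines such as
    //   0.0 0.0
    //   0.1 0.2
    //   9.0 9.0
    //   9.2 8.8
    SparkKMeans.main(Array("local", "kmeans_data.txt", "2", "0.001"))
  }
}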