Skip to content
Snippets Groups Projects
Commit acf82723 authored by root's avatar root Committed by Matei Zaharia
Browse files

Fix K-means example a little

parent d0f0fc8c
No related branches found
No related tags found
No related merge requests found
...@@ -49,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ...@@ -49,7 +49,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
return ans return ans
} }
def +=(other: Vector) { def += (other: Vector): Vector = {
if (length != other.length) if (length != other.length)
throw new IllegalArgumentException("Vectors of different length") throw new IllegalArgumentException("Vectors of different length")
var ans = 0.0 var ans = 0.0
...@@ -58,6 +58,7 @@ class Vector(val elements: Array[Double]) extends Serializable { ...@@ -58,6 +58,7 @@ class Vector(val elements: Array[Double]) extends Serializable {
elements(i) += other(i) elements(i) += other(i)
i += 1 i += 1
} }
this
} }
def * (scale: Double): Vector = Vector(length, i => this(i) * scale) def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
......
...@@ -15,14 +15,13 @@ object SparkKMeans { ...@@ -15,14 +15,13 @@ object SparkKMeans {
return new Vector(line.split(' ').map(_.toDouble)) return new Vector(line.split(' ').map(_.toDouble))
} }
def closestPoint(p: Vector, centers: HashMap[Int, Vector]): Int = { def closestPoint(p: Vector, centers: Array[Vector]): Int = {
var index = 0 var index = 0
var bestIndex = 0 var bestIndex = 0
var closest = Double.PositiveInfinity var closest = Double.PositiveInfinity
for (i <- 1 to centers.size) { for (i <- 0 until centers.length) {
val vCurr = centers.get(i).get val tempDist = p.squaredDist(centers(i))
val tempDist = p.squaredDist(vCurr)
if (tempDist < closest) { if (tempDist < closest) {
closest = tempDist closest = tempDist
bestIndex = i bestIndex = i
...@@ -43,32 +42,28 @@ object SparkKMeans { ...@@ -43,32 +42,28 @@ object SparkKMeans {
val K = args(2).toInt val K = args(2).toInt
val convergeDist = args(3).toDouble val convergeDist = args(3).toDouble
var points = data.takeSample(false, K, 42) var kPoints = data.takeSample(false, K, 42).toArray
var kPoints = new HashMap[Int, Vector]
var tempDist = 1.0 var tempDist = 1.0
for (i <- 1 to points.size) {
kPoints.put(i, points(i-1))
}
while(tempDist > convergeDist) { while(tempDist > convergeDist) {
var closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) var closest = data.map (p => (closestPoint(p, kPoints), (p, 1)))
var pointStats = closest.reduceByKey {case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} var pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)}
var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collect() var newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap()
tempDist = 0.0 tempDist = 0.0
for (pair <- newPoints) { for (i <- 0 until K) {
tempDist += kPoints.get(pair._1).get.squaredDist(pair._2) tempDist += kPoints(i).squaredDist(newPoints(i))
} }
for (newP <- newPoints) { for (newP <- newPoints) {
kPoints.put(newP._1, newP._2) kPoints(newP._1) = newP._2
} }
} }
println("Final centers: " + kPoints) println("Final centers:")
kPoints.foreach(println)
System.exit(0) System.exit(0)
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment