From 80c29689ae3b589254a571da3ddb5f9c866ae534 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Sun, 23 Mar 2014 17:34:02 -0700
Subject: [PATCH] [SPARK-1212] Adding sparse data support and update KMeans

Continue our discussions from https://github.com/apache/incubator-spark/pull/575

This PR is WIP because it depends on a SNAPSHOT version of breeze.

Per previous discussions and benchmarks, I switched to breeze for linear algebra operations. @dlwh and I made some improvements to breeze to keep its performance comparable to the bare-bone implementation, including norm computation and squared distance. This is why this PR needs to depend on a SNAPSHOT version of breeze.

@fommil , please find the notice of using netlib-core in `NOTICE`. This is following Apache's instructions on appropriate labeling.

I'm going to update this PR to include:

1. Fast distance computation: using `\|a\|_2^2 + \|b\|_2^2 - 2 a^T b` when it doesn't introduce too much numerical error. The squared norms are pre-computed. Otherwise, computing the distance between the center (dense) and a point (possibly sparse) always takes O(n) time.

2. Some numbers about the performance.

3. A released version of breeze. @dlwh, a minor release of breeze will help this PR get merged early. Do you mind sharing breeze's release plan? Thanks!

Author: Xiangrui Meng <meng@databricks.com>

Closes #117 from mengxr/sparse-kmeans and squashes the following commits:

67b368d [Xiangrui Meng] fix SparseVector.toArray
5eda0de [Xiangrui Meng] update NOTICE
67abe31 [Xiangrui Meng] move ArrayRDDs to mllib.rdd
1da1033 [Xiangrui Meng] remove dependency on commons-math3 and compute EPSILON directly
9bb1b31 [Xiangrui Meng] optimize SparseVector.toArray
226d2cd [Xiangrui Meng] update Java friendly methods in Vectors
238ba34 [Xiangrui Meng] add VectorRDDs with a converter from RDD[Array[Double]]
b28ba2f [Xiangrui Meng] add toArray to Vector
e69b10c [Xiangrui Meng] remove examples/JavaKMeans.java, which is replaced by mllib/examples/JavaKMeans.java
72bde33 [Xiangrui Meng] clean up code for distance computation
712cb88 [Xiangrui Meng] make Vectors.sparse Java friendly
27858e4 [Xiangrui Meng] update breeze version to 0.7
07c3cf2 [Xiangrui Meng] change Mahout to breeze in doc use a simple lower bound to avoid unnecessary distance computation
6f5cdde [Xiangrui Meng] fix a bug in filtering finished runs
42512f2 [Xiangrui Meng] Merge branch 'master' into sparse-kmeans
d6e6c07 [Xiangrui Meng] add predict(RDD[Vector]) to KMeansModel
42b4e50 [Xiangrui Meng] line feed at the end
a4ace73 [Xiangrui Meng] Merge branch 'fast-dist' into sparse-kmeans
3ed1a24 [Xiangrui Meng] add doc to BreezeVectorWithSquaredNorm
0107e19 [Xiangrui Meng] update NOTICE
87bc755 [Xiangrui Meng] tuned the KMeans code: changed some for loops to while, use view to avoid copying arrays
0ff8046 [Xiangrui Meng] update KMeans to use fastSquaredDistance
f355411 [Xiangrui Meng] add BreezeVectorWithSquaredNorm case class
ab74f67 [Xiangrui Meng] add fastSquaredDistance for KMeans
4e7d5ca [Xiangrui Meng] minor style update
07ffaf2 [Xiangrui Meng] add dense/sparse vector data models and conversions to/from breeze vectors use breeze to implement KMeans in order to support both dense and sparse data
---
 NOTICE                                        |   9 +
 .../org/apache/spark/examples/JavaKMeans.java | 138 -----------
 .../spark/mllib/examples/JavaKMeans.java      |  23 +-
 mllib/pom.xml                                 |   5 +
 .../mllib/api/python/PythonMLLibAPI.scala     |  12 +-
 .../spark/mllib/clustering/KMeans.scala       | 233 ++++++++++++------
 .../spark/mllib/clustering/KMeansModel.scala  |  24 +-
 .../spark/mllib/clustering/LocalKMeans.scala  |  58 +++--
 .../apache/spark/mllib/linalg/Vectors.scala   | 177 +++++++++++++
 .../apache/spark/mllib/rdd/VectorRDDs.scala   |  32 +++
 .../org/apache/spark/mllib/util/MLUtils.scala |  61 ++++-
 .../mllib/clustering/JavaKMeansSuite.java     |  88 +++----
 .../spark/mllib/linalg/JavaVectorsSuite.java  |  44 ++++
 .../spark/mllib/clustering/KMeansSuite.scala  | 175 +++++++------
 .../linalg/BreezeVectorConversionSuite.scala  |  58 +++++
 .../spark/mllib/linalg/VectorsSuite.scala     |  85 +++++++
 .../spark/mllib/rdd/VectorRDDsSuite.scala     |  33 +++
 .../spark/mllib/util/LocalSparkContext.scala  |  17 ++
 .../spark/mllib/util/MLUtilsSuite.scala       |  52 ++++
 project/SparkBuild.scala                      |   3 +-
 20 files changed, 930 insertions(+), 397 deletions(-)
 delete mode 100644 examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDs.scala
 create mode 100644 mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala

diff --git a/NOTICE b/NOTICE
index dce0c4eaf3..42f6c3a835 100644
--- a/NOTICE
+++ b/NOTICE
@@ -3,3 +3,12 @@ Copyright 2014 The Apache Software Foundation.
 
 This product includes software developed at
 The Apache Software Foundation (http://www.apache.org/).
+
+In addition, this product includes:
+
+- JUnit (http://www.junit.org) is a testing framework for Java. We included it
+  under the terms of the Eclipse Public License v1.0.
+
+- JTransforms (https://sites.google.com/site/piotrwendykier/software/jtransforms)
+  provides fast transforms in Java. It is tri-licensed, and we included it under 
+  the terms of the Mozilla Public License v1.1.
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
deleted file mode 100644
index 2d797279d5..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples;
-
-import scala.Tuple2;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.util.Vector;
-
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * K-means clustering using Java API.
- */
-public final class JavaKMeans {
-
-  private static final Pattern SPACE = Pattern.compile(" ");
-
-  /** Parses numbers split by whitespace to a vector */
-  static Vector parseVector(String line) {
-    String[] splits = SPACE.split(line);
-    double[] data = new double[splits.length];
-    int i = 0;
-    for (String s : splits) {
-      data[i] = Double.parseDouble(s);
-      i++;
-    }
-    return new Vector(data);
-  }
-
-  /** Computes the vector to which the input vector is closest using squared distance */
-  static int closestPoint(Vector p, List<Vector> centers) {
-    int bestIndex = 0;
-    double closest = Double.POSITIVE_INFINITY;
-    for (int i = 0; i < centers.size(); i++) {
-      double tempDist = p.squaredDist(centers.get(i));
-      if (tempDist < closest) {
-        closest = tempDist;
-        bestIndex = i;
-      }
-    }
-    return bestIndex;
-  }
-
-  /** Computes the mean across all vectors in the input set of vectors */
-  static Vector average(List<Vector> ps) {
-    int numVectors = ps.size();
-    Vector out = new Vector(ps.get(0).elements());
-    // start from i = 1 since we already copied index 0 above
-    for (int i = 1; i < numVectors; i++) {
-      out.addInPlace(ps.get(i));
-    }
-    return out.divide(numVectors);
-  }
-
-  public static void main(String[] args) throws Exception {
-    if (args.length < 4) {
-      System.err.println("Usage: JavaKMeans <master> <file> <k> <convergeDist>");
-      System.exit(1);
-    }
-    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
-      System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class));
-    String path = args[1];
-    int K = Integer.parseInt(args[2]);
-    double convergeDist = Double.parseDouble(args[3]);
-
-    JavaRDD<Vector> data = sc.textFile(path).map(
-      new Function<String, Vector>() {
-        @Override
-        public Vector call(String line) {
-          return parseVector(line);
-        }
-      }
-    ).cache();
-
-    final List<Vector> centroids = data.takeSample(false, K, 42);
-
-    double tempDist;
-    do {
-      // allocate each vector to closest centroid
-      JavaPairRDD<Integer, Vector> closest = data.mapToPair(
-        new PairFunction<Vector, Integer, Vector>() {
-          @Override
-          public Tuple2<Integer, Vector> call(Vector vector) {
-            return new Tuple2<Integer, Vector>(
-              closestPoint(vector, centroids), vector);
-          }
-        }
-      );
-
-      // group by cluster id and average the vectors within each cluster to compute centroids
-      JavaPairRDD<Integer, List<Vector>> pointsGroup = closest.groupByKey();
-      Map<Integer, Vector> newCentroids = pointsGroup.mapValues(
-        new Function<List<Vector>, Vector>() {
-          @Override
-          public Vector call(List<Vector> ps) {
-            return average(ps);
-          }
-        }).collectAsMap();
-      tempDist = 0.0;
-      for (int i = 0; i < K; i++) {
-        tempDist += centroids.get(i).squaredDist(newCentroids.get(i));
-      }
-      for (Map.Entry<Integer, Vector> t: newCentroids.entrySet()) {
-        centroids.set(t.getKey(), t.getValue());
-      }
-      System.out.println("Finished iteration (delta = " + tempDist + ")");
-    } while (tempDist > convergeDist);
-
-    System.out.println("Final centers:");
-    for (Vector c : centroids) {
-      System.out.println(c);
-    }
-
-    System.exit(0);
-
-  }
-}
diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
index 76ebdccfd6..7b0ec36424 100644
--- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
@@ -17,32 +17,33 @@
 
 package org.apache.spark.mllib.examples;
 
+import java.util.regex.Pattern;
+
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 
 import org.apache.spark.mllib.clustering.KMeans;
 import org.apache.spark.mllib.clustering.KMeansModel;
-
-import java.util.Arrays;
-import java.util.regex.Pattern;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
 
 /**
  * Example using MLLib KMeans from Java.
  */
 public final class JavaKMeans {
 
-  static class ParsePoint implements Function<String, double[]> {
+  private static class ParsePoint implements Function<String, Vector> {
     private static final Pattern SPACE = Pattern.compile(" ");
 
     @Override
-    public double[] call(String line) {
+    public Vector call(String line) {
       String[] tok = SPACE.split(line);
       double[] point = new double[tok.length];
       for (int i = 0; i < tok.length; ++i) {
         point[i] = Double.parseDouble(tok[i]);
       }
-      return point;
+      return Vectors.dense(point);
     }
   }
 
@@ -65,15 +66,15 @@ public final class JavaKMeans {
 
     JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
         System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class));
-    JavaRDD<String> lines = sc.textFile(args[1]);
+    JavaRDD<String> lines = sc.textFile(inputFile);
 
-    JavaRDD<double[]> points = lines.map(new ParsePoint());
+    JavaRDD<Vector> points = lines.map(new ParsePoint());
 
-    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);
+    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());
 
     System.out.println("Cluster centers:");
-    for (double[] center : model.clusterCenters()) {
-      System.out.println(" " + Arrays.toString(center));
+    for (Vector center : model.clusterCenters()) {
+      System.out.println(" " + center);
     }
     double cost = model.computeCost(points.rdd());
     System.out.println("Cost: " + cost);
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 9b65cb4b4c..fec1cc94b2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -60,6 +60,11 @@
       <artifactId>jblas</artifactId>
       <version>1.2.3</version>
     </dependency>
+    <dependency>
+      <groupId>org.scalanlp</groupId>
+      <artifactId>breeze_${scala.binary.version}</artifactId>
+      <version>0.7</version>
+    </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index efe99a31be..3449c698da 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -16,14 +16,16 @@
  */
 
 package org.apache.spark.mllib.api.python
+
+import java.nio.{ByteBuffer, ByteOrder}
+
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.recommendation._
+import org.apache.spark.mllib.regression._
 import org.apache.spark.rdd.RDD
-import java.nio.ByteBuffer
-import java.nio.ByteOrder
 
 /**
  * The Java stubs necessary for the Python mllib bindings.
@@ -205,10 +207,10 @@ class PythonMLLibAPI extends Serializable {
   def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
       maxIterations: Int, runs: Int, initializationMode: String):
       java.util.List[java.lang.Object] = {
-    val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes)))
     val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
     val ret = new java.util.LinkedList[java.lang.Object]()
-    ret.add(serializeDoubleMatrix(model.clusterCenters))
+    ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
     ret
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index e508b76c3f..b412738e3f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -19,16 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.jblas.DoubleMatrix
+import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm}
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
 
-
 /**
  * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
  * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
@@ -44,10 +43,7 @@ class KMeans private (
     var initializationMode: String,
     var initializationSteps: Int,
     var epsilon: Double)
-  extends Serializable with Logging
-{
-  private type ClusterCenters = Array[Array[Double]]
-
+  extends Serializable with Logging {
   def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4)
 
   /** Set the number of clusters to create (k). Default: 2. */
@@ -113,28 +109,50 @@ class KMeans private (
    * Train a K-means model on the given set of points; `data` should be cached for high
    * performance, because this is an iterative algorithm.
    */
-  def run(data: RDD[Array[Double]]): KMeansModel = {
-    // TODO: check whether data is persistent; this needs RDD.storageLevel to be publicly readable
+  def run(data: RDD[Vector]): KMeansModel = {
+    // Compute squared norms and cache them.
+    val norms = data.map(v => breezeNorm(v.toBreeze, 2.0))
+    norms.persist()
+    val breezeData = data.map(_.toBreeze).zip(norms).map { case (v, norm) =>
+      new BreezeVectorWithNorm(v, norm)
+    }
+    val model = runBreeze(breezeData)
+    norms.unpersist()
+    model
+  }
+
+  /**
+   * Implementation of K-Means using breeze.
+   */
+  private def runBreeze(data: RDD[BreezeVectorWithNorm]): KMeansModel = {
 
     val sc = data.sparkContext
 
+    val initStartTime = System.nanoTime()
+
     val centers = if (initializationMode == KMeans.RANDOM) {
       initRandom(data)
     } else {
       initKMeansParallel(data)
     }
 
+    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
+    logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) +
+      " seconds.")
+
     val active = Array.fill(runs)(true)
     val costs = Array.fill(runs)(0.0)
 
     var activeRuns = new ArrayBuffer[Int] ++ (0 until runs)
     var iteration = 0
 
+    val iterationStartTime = System.nanoTime()
+
     // Execute iterations of Lloyd's algorithm until all runs have converged
     while (iteration < maxIterations && !activeRuns.isEmpty) {
-      type WeightedPoint = (DoubleMatrix, Long)
+      type WeightedPoint = (BV[Double], Long)
       def mergeContribs(p1: WeightedPoint, p2: WeightedPoint): WeightedPoint = {
-        (p1._1.addi(p2._1), p1._2 + p2._2)
+        (p1._1 += p2._1, p1._2 + p2._2)
       }
 
       val activeCenters = activeRuns.map(r => centers(r)).toArray
@@ -144,16 +162,18 @@ class KMeans private (
       val totalContribs = data.mapPartitions { points =>
         val runs = activeCenters.length
         val k = activeCenters(0).length
-        val dims = activeCenters(0)(0).length
+        val dims = activeCenters(0)(0).vector.length
 
-        val sums = Array.fill(runs, k)(new DoubleMatrix(dims))
+        val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
         val counts = Array.fill(runs, k)(0L)
 
-        for (point <- points; (centers, runIndex) <- activeCenters.zipWithIndex) {
-          val (bestCenter, cost) = KMeans.findClosest(centers, point)
-          costAccums(runIndex) += cost
-          sums(runIndex)(bestCenter).addi(new DoubleMatrix(point))
-          counts(runIndex)(bestCenter) += 1
+        points.foreach { point =>
+          (0 until runs).foreach { i =>
+            val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+            costAccums(i) += cost
+            sums(i)(bestCenter) += point.vector
+            counts(i)(bestCenter) += 1
+          }
         }
 
         val contribs = for (i <- 0 until runs; j <- 0 until k) yield {
@@ -165,15 +185,18 @@ class KMeans private (
       // Update the cluster centers and costs for each active run
       for ((run, i) <- activeRuns.zipWithIndex) {
         var changed = false
-        for (j <- 0 until k) {
+        var j = 0
+        while (j < k) {
           val (sum, count) = totalContribs((i, j))
           if (count != 0) {
-            val newCenter = sum.divi(count).data
-            if (MLUtils.squaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
+            sum /= count.toDouble
+            val newCenter = new BreezeVectorWithNorm(sum)
+            if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {
               changed = true
             }
             centers(run)(j) = newCenter
           }
+          j += 1
         }
         if (!changed) {
           active(run) = false
@@ -186,17 +209,32 @@ class KMeans private (
       iteration += 1
     }
 
-    val bestRun = costs.zipWithIndex.min._2
-    new KMeansModel(centers(bestRun))
+    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
+    logInfo(s"Iterations took " + "%.3f".format(iterationTimeInSeconds) + " seconds.")
+
+    if (iteration == maxIterations) {
+      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
+    } else {
+      logInfo(s"KMeans converged in $iteration iterations.")
+    }
+
+    val (minCost, bestRun) = costs.zipWithIndex.min
+
+    logInfo(s"The cost for the best run is $minCost.")
+
+    new KMeansModel(centers(bestRun).map(c => Vectors.fromBreeze(c.vector)))
   }
 
   /**
    * Initialize `runs` sets of cluster centers at random.
    */
-  private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
+  private def initRandom(data: RDD[BreezeVectorWithNorm])
+  : Array[Array[BreezeVectorWithNorm]] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
     val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
-    Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
+    Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).map { v =>
+      new BreezeVectorWithNorm(v.vector.toDenseVector, v.norm)
+    }.toArray)
   }
 
   /**
@@ -208,38 +246,43 @@ class KMeans private (
    *
    * The original paper can be found at http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf.
    */
-  private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
+  private def initKMeansParallel(data: RDD[BreezeVectorWithNorm])
+  : Array[Array[BreezeVectorWithNorm]] = {
     // Initialize each run's center to a random point
     val seed = new XORShiftRandom().nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
+    val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
 
     // On each step, sample 2 * k points on average for each run with probability proportional
     // to their squared distance from that run's current centers
-    for (step <- 0 until initializationSteps) {
-      val centerArrays = centers.map(_.toArray)
+    var step = 0
+    while (step < initializationSteps) {
       val sumCosts = data.flatMap { point =>
-        for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
+        (0 until runs).map { r =>
+          (r, KMeans.pointCost(centers(r), point))
+        }
       }.reduceByKey(_ + _).collectAsMap()
       val chosen = data.mapPartitionsWithIndex { (index, points) =>
         val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
-        for {
-          p <- points
-          r <- 0 until runs
-          if rand.nextDouble() < KMeans.pointCost(centerArrays(r), p) * 2 * k / sumCosts(r)
-        } yield (r, p)
+        points.flatMap { p =>
+          (0 until runs).filter { r =>
+            rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+          }.map((_, p))
+        }
       }.collect()
-      for ((r, p) <- chosen) {
-        centers(r) += p
+      chosen.foreach { case (r, p) =>
+        centers(r) += p.toDense
       }
+      step += 1
     }
 
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
-    val centerArrays = centers.map(_.toArray)
     val weightMap = data.flatMap { p =>
-      for (r <- 0 until runs) yield ((r, KMeans.findClosest(centerArrays(r), p)._1), 1.0)
+      (0 until runs).map { r =>
+        ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+      }
     }.reduceByKey(_ + _).collectAsMap()
     val finalCenters = (0 until runs).map { r =>
       val myCenters = centers(r).toArray
@@ -256,63 +299,75 @@ class KMeans private (
  * Top-level methods for calling K-means clustering.
  */
 object KMeans {
+
   // Initialization mode names
   val RANDOM = "random"
   val K_MEANS_PARALLEL = "k-means||"
 
+  /**
+   * Trains a k-means model using the given set of parameters.
+   *
+   * @param data training points stored as `RDD[Array[Double]]`
+   * @param k number of clusters
+   * @param maxIterations max number of iterations
+   * @param runs number of parallel runs, defaults to 1. The best model is returned.
+   * @param initializationMode initialization model, either "random" or "k-means||" (default).
+   */
   def train(
-      data: RDD[Array[Double]],
+      data: RDD[Vector],
       k: Int,
       maxIterations: Int,
-      runs: Int,
-      initializationMode: String)
-    : KMeansModel =
-  {
+      runs: Int = 1,
+      initializationMode: String = K_MEANS_PARALLEL): KMeansModel = {
     new KMeans().setK(k)
-                .setMaxIterations(maxIterations)
-                .setRuns(runs)
-                .setInitializationMode(initializationMode)
-                .run(data)
-  }
-
-  def train(data: RDD[Array[Double]], k: Int, maxIterations: Int, runs: Int): KMeansModel = {
-    train(data, k, maxIterations, runs, K_MEANS_PARALLEL)
-  }
-
-  def train(data: RDD[Array[Double]], k: Int, maxIterations: Int): KMeansModel = {
-    train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
+      .setMaxIterations(maxIterations)
+      .setRuns(runs)
+      .setInitializationMode(initializationMode)
+      .run(data)
   }
 
   /**
-   * Return the index of the closest point in `centers` to `point`, as well as its distance.
+   * Returns the index of the closest center to the given point, as well as the squared distance.
    */
-  private[mllib] def findClosest(centers: Array[Array[Double]], point: Array[Double])
-    : (Int, Double) =
-  {
+  private[mllib] def findClosest(
+      centers: TraversableOnce[BreezeVectorWithNorm],
+      point: BreezeVectorWithNorm): (Int, Double) = {
     var bestDistance = Double.PositiveInfinity
     var bestIndex = 0
-    for (i <- 0 until centers.length) {
-      val distance = MLUtils.squaredDistance(point, centers(i))
-      if (distance < bestDistance) {
-        bestDistance = distance
-        bestIndex = i
+    var i = 0
+    centers.foreach { center =>
+      // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
+      // distance computation.
+      var lowerBoundOfSqDist = center.norm - point.norm
+      lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
+      if (lowerBoundOfSqDist < bestDistance) {
+        val distance: Double = fastSquaredDistance(center, point)
+        if (distance < bestDistance) {
+          bestDistance = distance
+          bestIndex = i
+        }
       }
+      i += 1
     }
     (bestIndex, bestDistance)
   }
 
   /**
-   * Return the K-means cost of a given point against the given cluster centers.
+   * Returns the K-means cost of a given point against the given cluster centers.
    */
-  private[mllib] def pointCost(centers: Array[Array[Double]], point: Array[Double]): Double = {
-    var bestDistance = Double.PositiveInfinity
-    for (i <- 0 until centers.length) {
-      val distance = MLUtils.squaredDistance(point, centers(i))
-      if (distance < bestDistance) {
-        bestDistance = distance
-      }
-    }
-    bestDistance
+  private[mllib] def pointCost(
+      centers: TraversableOnce[BreezeVectorWithNorm],
+      point: BreezeVectorWithNorm): Double =
+    findClosest(centers, point)._2
+
+  /**
+   * Returns the squared Euclidean distance between two vectors computed by
+   * [[org.apache.spark.mllib.util.MLUtils#fastSquaredDistance]].
+   */
+  private[clustering]
+  def fastSquaredDistance(v1: BreezeVectorWithNorm, v2: BreezeVectorWithNorm)
+  : Double = {
+    MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
   }
 
   def main(args: Array[String]) {
@@ -323,14 +378,34 @@ object KMeans {
     val (master, inputFile, k, iters) = (args(0), args(1), args(2).toInt, args(3).toInt)
     val runs = if (args.length >= 5) args(4).toInt else 1
     val sc = new SparkContext(master, "KMeans")
-    val data = sc.textFile(inputFile).map(line => line.split(' ').map(_.toDouble)).cache()
+    val data = sc.textFile(inputFile)
+      .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
+      .cache()
     val model = KMeans.train(data, k, iters, runs)
     val cost = model.computeCost(data)
     println("Cluster centers:")
     for (c <- model.clusterCenters) {
-      println("  " + c.mkString(" "))
+      println("  " + c)
     }
     println("Cost: " + cost)
     System.exit(0)
   }
 }
+
+/**
+ * A breeze vector with its norm for fast distance computation.
+ *
+ * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
+ */
+private[clustering]
+class BreezeVectorWithNorm(val vector: BV[Double], val norm: Double) extends Serializable {
+
+  def this(vector: BV[Double]) = this(vector, breezeNorm(vector, 2.0))
+
+  def this(array: Array[Double]) = this(new BDV[Double](array))
+
+  def this(v: Vector) = this(v.toBreeze)
+
+  /** Converts the vector to a dense vector. */
+  def toDense = new BreezeVectorWithNorm(vector.toDenseVector, norm)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 980be93157..18abbf2758 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -19,24 +19,36 @@ package org.apache.spark.mllib.clustering
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.linalg.Vector
 
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel(val clusterCenters: Array[Array[Double]]) extends Serializable {
+class KMeansModel(val clusterCenters: Array[Vector]) extends Serializable {
+
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
 
-  /** Return the cluster index that a given point belongs to. */
-  def predict(point: Array[Double]): Int = {
-    KMeans.findClosest(clusterCenters, point)._1
+  /** Returns the cluster index that a given point belongs to. */
+  def predict(point: Vector): Int = {
+    KMeans.findClosest(clusterCentersWithNorm, new BreezeVectorWithNorm(point))._1
+  }
+
+  /** Maps given points to their cluster indices. */
+  def predict(points: RDD[Vector]): RDD[Int] = {
+    val centersWithNorm = clusterCentersWithNorm
+    points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
   }
 
   /**
    * Return the K-means cost (sum of squared distances of points to their nearest center) for this
    * model on the given data.
    */
-  def computeCost(data: RDD[Array[Double]]): Double = {
-    data.map(p => KMeans.pointCost(clusterCenters, p)).sum()
+  def computeCost(data: RDD[Vector]): Double = {
+    val centersWithNorm = clusterCentersWithNorm
+    data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
   }
+
+  private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =
+    clusterCenters.map(new BreezeVectorWithNorm(_))
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
index baf8251d8f..2e3a4ce783 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala
@@ -19,35 +19,37 @@ package org.apache.spark.mllib.clustering
 
 import scala.util.Random
 
-import org.jblas.{DoubleMatrix, SimpleBlas}
+import breeze.linalg.{Vector => BV, DenseVector => BDV, norm => breezeNorm}
+
+import org.apache.spark.Logging
 
 /**
  * An utility object to run K-means locally. This is private to the ML package because it's used
  * in the initialization of KMeans but not meant to be publicly exposed.
  */
-private[mllib] object LocalKMeans {
+private[mllib] object LocalKMeans extends Logging {
+
   /**
    * Run K-means++ on the weighted point set `points`. This first does the K-means++
-   * initialization procedure and then roudns of Lloyd's algorithm.
+   * initialization procedure and then rounds of Lloyd's algorithm.
    */
   def kMeansPlusPlus(
       seed: Int,
-      points: Array[Array[Double]],
+      points: Array[BreezeVectorWithNorm],
       weights: Array[Double],
       k: Int,
-      maxIterations: Int)
-    : Array[Array[Double]] =
-  {
+      maxIterations: Int
+  ): Array[BreezeVectorWithNorm] = {
     val rand = new Random(seed)
-    val dimensions = points(0).length
-    val centers = new Array[Array[Double]](k)
+    val dimensions = points(0).vector.length
+    val centers = new Array[BreezeVectorWithNorm](k)
 
-    // Initialize centers by sampling using the k-means++ procedure
-    centers(0) = pickWeighted(rand, points, weights)
+    // Initialize centers by sampling using the k-means++ procedure.
+    centers(0) = pickWeighted(rand, points, weights).toDense
     for (i <- 1 until k) {
       // Pick the next center with a probability proportional to cost under current centers
-      val curCenters = centers.slice(0, i)
-      val sum = points.zip(weights).map { case (p, w) =>
+      val curCenters = centers.view.take(i)
+      val sum = points.view.zip(weights).map { case (p, w) =>
         w * KMeans.pointCost(curCenters, p)
       }.sum
       val r = rand.nextDouble() * sum
@@ -57,7 +59,7 @@ private[mllib] object LocalKMeans {
         cumulativeScore += weights(j) * KMeans.pointCost(curCenters, points(j))
         j += 1
       }
-      centers(i) = points(j-1)
+      centers(i) = points(j-1).toDense
     }
 
     // Run up to maxIterations iterations of Lloyd's algorithm
@@ -66,29 +68,43 @@ private[mllib] object LocalKMeans {
     var moved = true
     while (moved && iteration < maxIterations) {
       moved = false
-      val sums = Array.fill(k)(new DoubleMatrix(dimensions))
       val counts = Array.fill(k)(0.0)
-      for ((p, i) <- points.zipWithIndex) {
+      val sums = Array.fill(k)(
+        BDV.zeros[Double](dimensions).asInstanceOf[BV[Double]]
+      )
+      var i = 0
+      while (i < points.length) {
+        val p = points(i)
         val index = KMeans.findClosest(centers, p)._1
-        SimpleBlas.axpy(weights(i), new DoubleMatrix(p), sums(index))
+        breeze.linalg.axpy(weights(i), p.vector, sums(index))
         counts(index) += weights(i)
         if (index != oldClosest(i)) {
           moved = true
           oldClosest(i) = index
         }
+        i += 1
       }
       // Update centers
-      for (i <- 0 until k) {
-        if (counts(i) == 0.0) {
+      var j = 0
+      while (j < k) {
+        if (counts(j) == 0.0) {
           // Assign center to a random point
-          centers(i) = points(rand.nextInt(points.length))
+          centers(j) = points(rand.nextInt(points.length)).toDense
         } else {
-          centers(i) = sums(i).divi(counts(i)).data
+          sums(j) /= counts(j)
+          centers(j) = new BreezeVectorWithNorm(sums(j))
         }
+        j += 1
       }
       iteration += 1
     }
 
+    if (iteration == maxIterations) {
+      logInfo(s"Local KMeans++ reached the max number of iterations: $maxIterations.")
+    } else {
+      logInfo(s"Local KMeans++ converged in $iteration iterations.")
+    }
+
     centers
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
new file mode 100644
index 0000000000..01c1501548
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import java.lang.{Iterable => JavaIterable, Integer => JavaInteger, Double => JavaDouble}
+import java.util.Arrays
+
+import scala.annotation.varargs
+import scala.collection.JavaConverters._
+
+import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV}
+
+/**
+ * Represents a numeric vector, whose index type is Int and value type is Double.
+ */
+trait Vector extends Serializable {
+
+  /**
+   * Size of the vector.
+   */
+  def size: Int
+
+  /**
+   * Converts the instance to a double array.
+   */
+  def toArray: Array[Double]
+
+  override def equals(other: Any): Boolean = {
+    other match {
+      case v: Vector =>
+        Arrays.equals(this.toArray, v.toArray)
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = Arrays.hashCode(this.toArray)
+
+  /**
+   * Converts the instance to a breeze vector.
+   */
+  private[mllib] def toBreeze: BV[Double]
+}
+
+/**
+ * Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
+ */
+object Vectors {
+
+  /**
+   * Creates a dense vector.
+   */
+  @varargs
+  def dense(firstValue: Double, otherValues: Double*): Vector =
+    new DenseVector((firstValue +: otherValues).toArray)
+
+  // A dummy implicit is used to avoid signature collision with the one generated by @varargs.
+  /**
+   * Creates a dense vector from a double array.
+   */
+  def dense(values: Array[Double]): Vector = new DenseVector(values)
+
+  /**
+   * Creates a sparse vector providing its index array and value array.
+   *
+   * @param size vector size.
+   * @param indices index array, must be strictly increasing.
+   * @param values value array, must have the same length as indices.
+   */
+  def sparse(size: Int, indices: Array[Int], values: Array[Double]): Vector =
+    new SparseVector(size, indices, values)
+
+  /**
+   * Creates a sparse vector using unordered (index, value) pairs.
+   *
+   * @param size vector size.
+   * @param elements vector elements in (index, value) pairs.
+   */
+  def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
+    require(size > 0)
+
+    val (indices, values) = elements.sortBy(_._1).unzip
+    var prev = -1
+    indices.foreach { i =>
+      require(prev < i, s"Found duplicate indices: $i.")
+      prev = i
+    }
+    require(prev < size)
+
+    new SparseVector(size, indices.toArray, values.toArray)
+  }
+
+  /**
+   * Creates a sparse vector using unordered (index, value) pairs in a Java friendly way.
+   *
+   * @param size vector size.
+   * @param elements vector elements in (index, value) pairs.
+   */
+  def sparse(size: Int, elements: JavaIterable[(JavaInteger, JavaDouble)]): Vector = {
+    sparse(size, elements.asScala.map { case (i, x) =>
+      (i.intValue(), x.doubleValue())
+    }.toSeq)
+  }
+
+  /**
+   * Creates a vector instance from a breeze vector.
+   */
+  private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
+    breezeVector match {
+      case v: BDV[Double] =>
+        require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.")
+        require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.")
+        new DenseVector(v.data)
+      case v: BSV[Double] =>
+        new SparseVector(v.length, v.index, v.data)
+      case v: BV[_] =>
+        sys.error("Unsupported Breeze vector type: " + v.getClass.getName)
+    }
+  }
+}
+
+/**
+ * A dense vector represented by a value array.
+ */
+class DenseVector(val values: Array[Double]) extends Vector {
+
+  override def size: Int = values.length
+
+  override def toString: String = values.mkString("[", ",", "]")
+
+  override def toArray: Array[Double] = values
+
+  private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
+}
+
+/**
+ * A sparse vector represented by an index array and an value array.
+ *
+ * @param n size of the vector.
+ * @param indices index array, assume to be strictly increasing.
+ * @param values value array, must have the same length as the index array.
+ */
+class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double]) extends Vector {
+
+  override def size: Int = n
+
+  override def toString: String = {
+    "(" + n + "," + indices.zip(values).mkString("[", "," ,"]") + ")"
+  }
+
+  override def toArray: Array[Double] = {
+    val data = new Array[Double](n)
+    var i = 0
+    val nnz = indices.length
+    while (i < nnz) {
+      data(indices(i)) = values(i)
+      i += 1
+    }
+    data
+  }
+
+  private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, n)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDs.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDs.scala
new file mode 100644
index 0000000000..9096d6a1a1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDs.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Vectors, Vector}
+
+/**
+ * Factory methods for `RDD[Vector]`.
+ */
+object VectorRDDs {
+
+  /**
+   * Converts an `RDD[Array[Double]]` to `RDD[Vector]`.
+   */
+  def fromArrayRDD(rdd: RDD[Array[Double]]): RDD[Vector] = rdd.map(v => Vectors.dense(v))
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 64c6136a8b..08cd9ab055 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -22,13 +22,24 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
 import org.jblas.DoubleMatrix
+
 import org.apache.spark.mllib.regression.LabeledPoint
 
+import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
+
 /**
  * Helper methods to load, save and pre-process data used in ML Lib.
  */
 object MLUtils {
 
+  private[util] lazy val EPSILON = {
+    var eps = 1.0
+    while ((1.0 + (eps / 2.0)) != 1.0) {
+      eps /= 2.0
+    }
+    eps
+  }
+
   /**
    * Load labeled data from a file. The data format used here is
    * <L>, <f1> <f2> ...
@@ -106,18 +117,46 @@ object MLUtils {
   }
 
   /**
-   * Return the squared Euclidean distance between two vectors.
+   * Returns the squared Euclidean distance between two vectors. The following formula will be used
+   * if it does not introduce too much numerical error:
+   * <pre>
+   *   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.
+   * </pre>
+   * When both vector norms are given, this is faster than computing the squared distance directly,
+   * especially when one of the vectors is a sparse vector.
+   *
+   * @param v1 the first vector
+   * @param norm1 the norm of the first vector, non-negative
+   * @param v2 the second vector
+   * @param norm2 the norm of the second vector, non-negative
+   * @param precision desired relative precision for the squared distance
+   * @return squared distance between v1 and v2 within the specified precision
    */
-  def squaredDistance(v1: Array[Double], v2: Array[Double]): Double = {
-    if (v1.length != v2.length) {
-      throw new IllegalArgumentException("Vector sizes don't match")
-    }
-    var i = 0
-    var sum = 0.0
-    while (i < v1.length) {
-      sum += (v1(i) - v2(i)) * (v1(i) - v2(i))
-      i += 1
+  private[mllib] def fastSquaredDistance(
+      v1: BV[Double],
+      norm1: Double,
+      v2: BV[Double],
+      norm2: Double,
+      precision: Double = 1e-6): Double = {
+    val n = v1.size
+    require(v2.size == n)
+    require(norm1 >= 0.0 && norm2 >= 0.0)
+    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
+    val normDiff = norm1 - norm2
+    var sqDist = 0.0
+    val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
+    if (precisionBound1 < precision) {
+      sqDist = sumSquaredNorm - 2.0 * v1.dot(v2)
+    } else if (v1.isInstanceOf[BSV[Double]] || v2.isInstanceOf[BSV[Double]]) {
+      val dot = v1.dot(v2)
+      sqDist = math.max(sumSquaredNorm - 2.0 * dot, 0.0)
+      val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dot)) / (sqDist + EPSILON)
+      if (precisionBound2 > precision) {
+        sqDist = breezeSquaredDistance(v1, v2)
+      }
+    } else {
+      sqDist = breezeSquaredDistance(v1, v2)
     }
-    sum
+    sqDist
   }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
index 33b99f4bd3..49a614bd90 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaKMeansSuite.java
@@ -18,16 +18,19 @@
 package org.apache.spark.mllib.clustering;
 
 import java.io.Serializable;
-import java.util.ArrayList;
 import java.util.List;
 
 import org.junit.After;
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import static org.junit.Assert.*;
+
+import com.google.common.collect.Lists;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
 
 public class JavaKMeansSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -44,72 +47,45 @@ public class JavaKMeansSuite implements Serializable {
     System.clearProperty("spark.driver.port");
   }
 
-  // L1 distance between two points
-  double distance1(double[] v1, double[] v2) {
-    double distance = 0.0;
-    for (int i = 0; i < v1.length; ++i) {
-      distance = Math.max(distance, Math.abs(v1[i] - v2[i]));
-    }
-    return distance;
-  }
-
-  // Assert that two sets of points are equal, within EPSILON tolerance
-  void assertSetsEqual(double[][] v1, double[][] v2) {
-    double EPSILON = 1e-4;
-    Assert.assertTrue(v1.length == v2.length);
-    for (int i = 0; i < v1.length; ++i) {
-      double minDistance = Double.MAX_VALUE;
-      for (int j = 0; j < v2.length; ++j) {
-        minDistance = Math.min(minDistance, distance1(v1[i], v2[j]));
-      }
-      Assert.assertTrue(minDistance <= EPSILON);
-    }
-
-    for (int i = 0; i < v2.length; ++i) {
-      double minDistance = Double.MAX_VALUE;
-      for (int j = 0; j < v1.length; ++j) {
-        minDistance = Math.min(minDistance, distance1(v2[i], v1[j]));
-      }
-      Assert.assertTrue(minDistance <= EPSILON);
-    }
-  }
-
-
   @Test
   public void runKMeansUsingStaticMethods() {
-    List<double[]> points = new ArrayList<double[]>();
-    points.add(new double[]{1.0, 2.0, 6.0});
-    points.add(new double[]{1.0, 3.0, 0.0});
-    points.add(new double[]{1.0, 4.0, 6.0});
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
 
-    double[][] expectedCenter = { {1.0, 3.0, 4.0} };
+    Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
 
-    JavaRDD<double[]> data = sc.parallelize(points, 2);
-    KMeansModel model = KMeans.train(data.rdd(), 1, 1);
-    assertSetsEqual(model.clusterCenters(), expectedCenter);
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
+    KMeansModel model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.K_MEANS_PARALLEL());
+    assertEquals(1, model.clusterCenters().length);
+    assertEquals(expectedCenter, model.clusterCenters()[0]);
 
     model = KMeans.train(data.rdd(), 1, 1, 1, KMeans.RANDOM());
-    assertSetsEqual(model.clusterCenters(), expectedCenter);
+    assertEquals(expectedCenter, model.clusterCenters()[0]);
   }
 
   @Test
   public void runKMeansUsingConstructor() {
-    List<double[]> points = new ArrayList<double[]>();
-    points.add(new double[]{1.0, 2.0, 6.0});
-    points.add(new double[]{1.0, 3.0, 0.0});
-    points.add(new double[]{1.0, 4.0, 6.0});
+    List<Vector> points = Lists.newArrayList(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
+    );
 
-    double[][] expectedCenter = { {1.0, 3.0, 4.0} };
+    Vector expectedCenter = Vectors.dense(1.0, 3.0, 4.0);
 
-    JavaRDD<double[]> data = sc.parallelize(points, 2);
+    JavaRDD<Vector> data = sc.parallelize(points, 2);
     KMeansModel model = new KMeans().setK(1).setMaxIterations(5).run(data.rdd());
-    assertSetsEqual(model.clusterCenters(), expectedCenter);
-
-    model = new KMeans().setK(1)
-                        .setMaxIterations(1)
-                        .setRuns(1)
-                        .setInitializationMode(KMeans.RANDOM())
-                        .run(data.rdd());
-    assertSetsEqual(model.clusterCenters(), expectedCenter);
+    assertEquals(1, model.clusterCenters().length);
+    assertEquals(expectedCenter, model.clusterCenters()[0]);
+
+    model = new KMeans()
+      .setK(1)
+      .setMaxIterations(1)
+      .setInitializationMode(KMeans.RANDOM())
+      .run(data.rdd());
+    assertEquals(expectedCenter, model.clusterCenters()[0]);
   }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
new file mode 100644
index 0000000000..2c4d795f96
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+
+import scala.Tuple2;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+public class JavaVectorsSuite implements Serializable {
+
+  @Test
+  public void denseArrayConstruction() {
+    Vector v = Vectors.dense(1.0, 2.0, 3.0);
+    assertArrayEquals(new double[]{1.0, 2.0, 3.0}, v.toArray(), 0.0);
+  }
+
+  @Test
+  public void sparseArrayConstruction() {
+    Vector v = Vectors.sparse(3, Lists.newArrayList(
+        new Tuple2<Integer, Double>(0, 2.0),
+        new Tuple2<Integer, Double>(2, 3.0)));
+    assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index 4ef1d1f64f..560a4ad71a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -17,127 +17,139 @@
 
 package org.apache.spark.mllib.clustering
 
-
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
 
 class KMeansSuite extends FunSuite with LocalSparkContext {
 
-  val EPSILON = 1e-4
-
   import KMeans.{RANDOM, K_MEANS_PARALLEL}
 
-  def prettyPrint(point: Array[Double]): String = point.mkString("(", ", ", ")")
-
-  def prettyPrint(points: Array[Array[Double]]): String = {
-    points.map(prettyPrint).mkString("(", "; ", ")")
-  }
-
-  // L1 distance between two points
-  def distance1(v1: Array[Double], v2: Array[Double]): Double = {
-    v1.zip(v2).map{ case (a, b) => math.abs(a-b) }.max
-  }
-
-  // Assert that two vectors are equal within tolerance EPSILON
-  def assertEqual(v1: Array[Double], v2: Array[Double]) {
-    def errorMessage = prettyPrint(v1) + " did not equal " + prettyPrint(v2)
-    assert(v1.length == v2.length, errorMessage)
-    assert(distance1(v1, v2) <= EPSILON, errorMessage)
-  }
-
-  // Assert that two sets of points are equal, within EPSILON tolerance
-  def assertSetsEqual(set1: Array[Array[Double]], set2: Array[Array[Double]]) {
-    def errorMessage = prettyPrint(set1) + " did not equal " + prettyPrint(set2)
-    assert(set1.length == set2.length, errorMessage)
-    for (v <- set1) {
-      val closestDistance = set2.map(w => distance1(v, w)).min
-      if (closestDistance > EPSILON) {
-        fail(errorMessage)
-      }
-    }
-    for (v <- set2) {
-      val closestDistance = set1.map(w => distance1(v, w)).min
-      if (closestDistance > EPSILON) {
-        fail(errorMessage)
-      }
-    }
-  }
-
   test("single cluster") {
     val data = sc.parallelize(Array(
-      Array(1.0, 2.0, 6.0),
-      Array(1.0, 3.0, 0.0),
-      Array(1.0, 4.0, 6.0)
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
     ))
 
+    val center = Vectors.dense(1.0, 3.0, 4.0)
+
     // No matter how many runs or iterations we use, we should get one cluster,
     // centered at the mean of the points
 
     var model = KMeans.train(data, k=1, maxIterations=1)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=2)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(
       data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
   }
 
   test("single cluster with big dataset") {
     val smallData = Array(
-      Array(1.0, 2.0, 6.0),
-      Array(1.0, 3.0, 0.0),
-      Array(1.0, 4.0, 6.0)
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0)
     )
     val data = sc.parallelize((1 to 100).flatMap(_ => smallData), 4)
 
     // No matter how many runs or iterations we use, we should get one cluster,
     // centered at the mean of the points
 
+    val center = Vectors.dense(1.0, 3.0, 4.0)
+
     var model = KMeans.train(data, k=1, maxIterations=1)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.size === 1)
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=2)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=5)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
 
     model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
-    assertSetsEqual(model.clusterCenters, Array(Array(1.0, 3.0, 4.0)))
+    assert(model.clusterCenters.head === center)
+  }
+
+  test("single cluster with sparse data") {
+
+    val n = 10000
+    val data = sc.parallelize((1 to 100).flatMap { i =>
+      val x = i / 1000.0
+      Array(
+        Vectors.sparse(n, Seq((0, 1.0 + x), (1, 2.0), (2, 6.0))),
+        Vectors.sparse(n, Seq((0, 1.0 - x), (1, 2.0), (2, 6.0))),
+        Vectors.sparse(n, Seq((0, 1.0), (1, 3.0 + x))),
+        Vectors.sparse(n, Seq((0, 1.0), (1, 3.0 - x))),
+        Vectors.sparse(n, Seq((0, 1.0), (1, 4.0), (2, 6.0 + x))),
+        Vectors.sparse(n, Seq((0, 1.0), (1, 4.0), (2, 6.0 - x)))
+      )
+    }, 4)
+
+    data.persist()
+
+    // No matter how many runs or iterations we use, we should get one cluster,
+    // centered at the mean of the points
+
+    val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
+
+    var model = KMeans.train(data, k=1, maxIterations=1)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=2)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=5)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=1, runs=5)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=RANDOM)
+    assert(model.clusterCenters.head === center)
+
+    model = KMeans.train(data, k=1, maxIterations=1, runs=1, initializationMode=K_MEANS_PARALLEL)
+    assert(model.clusterCenters.head === center)
+
+    data.unpersist()
   }
 
   test("k-means|| initialization") {
-    val points = Array(
-      Array(1.0, 2.0, 6.0),
-      Array(1.0, 3.0, 0.0),
-      Array(1.0, 4.0, 6.0),
-      Array(1.0, 0.0, 1.0),
-      Array(1.0, 1.0, 1.0)
+    val points = Seq(
+      Vectors.dense(1.0, 2.0, 6.0),
+      Vectors.dense(1.0, 3.0, 0.0),
+      Vectors.dense(1.0, 4.0, 6.0),
+      Vectors.dense(1.0, 0.0, 1.0),
+      Vectors.dense(1.0, 1.0, 1.0)
     )
     val rdd = sc.parallelize(points)
 
@@ -146,14 +158,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
     // unselected point as long as it hasn't yet selected all of them
 
     var model = KMeans.train(rdd, k=5, maxIterations=1)
-    assertSetsEqual(model.clusterCenters, points)
+    assert(Set(model.clusterCenters: _*) === Set(points: _*))
 
     // Iterations of Lloyd's should not change the answer either
     model = KMeans.train(rdd, k=5, maxIterations=10)
-    assertSetsEqual(model.clusterCenters, points)
+    assert(Set(model.clusterCenters: _*) === Set(points: _*))
 
     // Neither should more runs
     model = KMeans.train(rdd, k=5, maxIterations=10, runs=5)
-    assertSetsEqual(model.clusterCenters, points)
+    assert(Set(model.clusterCenters: _*) === Set(points: _*))
+  }
+
+  test("two clusters") {
+    val points = Seq(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(0.0, 0.1),
+      Vectors.dense(0.1, 0.0),
+      Vectors.dense(9.0, 0.0),
+      Vectors.dense(9.0, 0.2),
+      Vectors.dense(9.2, 0.0)
+    )
+    val rdd = sc.parallelize(points, 3)
+
+    for (initMode <- Seq(RANDOM, K_MEANS_PARALLEL)) {
+      // Two iterations are sufficient no matter where the initial centers are.
+      val model = KMeans.train(rdd, k = 2, maxIterations = 2, runs = 1, initMode)
+
+      val predicts = model.predict(rdd).collect()
+
+      assert(predicts(0) === predicts(1))
+      assert(predicts(0) === predicts(2))
+      assert(predicts(3) === predicts(4))
+      assert(predicts(3) === predicts(5))
+      assert(predicts(0) != predicts(3))
+    }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
new file mode 100644
index 0000000000..aacaa30084
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeVectorConversionSuite.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.scalatest.FunSuite
+
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+
+/**
+ * Test Breeze vector conversions.
+ */
+class BreezeVectorConversionSuite extends FunSuite {
+
+  val arr = Array(0.1, 0.2, 0.3, 0.4)
+  val n = 20
+  val indices = Array(0, 3, 5, 10, 13)
+  val values = Array(0.1, 0.5, 0.3, -0.8, -1.0)
+
+  test("dense to breeze") {
+    val vec = Vectors.dense(arr)
+    assert(vec.toBreeze === new BDV[Double](arr))
+  }
+
+  test("sparse to breeze") {
+    val vec = Vectors.sparse(n, indices, values)
+    assert(vec.toBreeze === new BSV[Double](indices, values, n))
+  }
+
+  test("dense breeze to vector") {
+    val breeze = new BDV[Double](arr)
+    val vec = Vectors.fromBreeze(breeze).asInstanceOf[DenseVector]
+    assert(vec.size === arr.length)
+    assert(vec.values.eq(arr), "should not copy data")
+  }
+
+  test("sparse breeze to vector") {
+    val breeze = new BSV[Double](indices, values, n)
+    val vec = Vectors.fromBreeze(breeze).asInstanceOf[SparseVector]
+    assert(vec.size === n)
+    assert(vec.indices.eq(indices), "should not copy data")
+    assert(vec.values.eq(values), "should not copy data")
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
new file mode 100644
index 0000000000..8a200310e0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg
+
+import org.scalatest.FunSuite
+
+class VectorsSuite extends FunSuite {
+
+  val arr = Array(0.1, 0.0, 0.3, 0.4)
+  val n = 4
+  val indices = Array(0, 2, 3)
+  val values = Array(0.1, 0.3, 0.4)
+
+  test("dense vector construction with varargs") {
+    val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
+    assert(vec.size === arr.length)
+    assert(vec.values.eq(arr))
+  }
+
+  test("dense vector construction from a double array") {
+   val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
+    assert(vec.size === arr.length)
+    assert(vec.values.eq(arr))
+  }
+
+  test("sparse vector construction") {
+    val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
+    assert(vec.size === n)
+    assert(vec.indices.eq(indices))
+    assert(vec.values.eq(values))
+  }
+
+  test("sparse vector construction with unordered elements") {
+    val vec = Vectors.sparse(n, indices.zip(values).reverse).asInstanceOf[SparseVector]
+    assert(vec.size === n)
+    assert(vec.indices === indices)
+    assert(vec.values === values)
+  }
+
+  test("dense to array") {
+    val vec = Vectors.dense(arr).asInstanceOf[DenseVector]
+    assert(vec.toArray.eq(arr))
+  }
+
+  test("sparse to array") {
+    val vec = Vectors.sparse(n, indices, values).asInstanceOf[SparseVector]
+    assert(vec.toArray === arr)
+  }
+
+  test("vector equals") {
+    val dv1 = Vectors.dense(arr.clone())
+    val dv2 = Vectors.dense(arr.clone())
+    val sv1 = Vectors.sparse(n, indices.clone(), values.clone())
+    val sv2 = Vectors.sparse(n, indices.clone(), values.clone())
+
+    val vectors = Seq(dv1, dv2, sv1, sv2)
+
+    for (v <- vectors; u <- vectors) {
+      assert(v === u)
+      assert(v.## === u.##)
+    }
+
+    val another = Vectors.dense(0.1, 0.2, 0.3, 0.4)
+
+    for (v <- vectors) {
+      assert(v != another)
+      assert(v.## != another.##)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala
new file mode 100644
index 0000000000..692f025e95
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDsSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.rdd
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class VectorRDDsSuite extends FunSuite with LocalSparkContext {
+
+  test("from array rdd") {
+    val data = Seq(Array(1.0, 2.0), Array(3.0, 4.0))
+    val arrayRdd = sc.parallelize(data, 2)
+    val vectorRdd = VectorRDDs.fromArrayRDD(arrayRdd)
+    assert(arrayRdd.collect().map(v => Vectors.dense(v)) === vectorRdd.collect())
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
index 7d840043e5..212fbe9288 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.spark.mllib.util
 
 import org.scalatest.Suite
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
new file mode 100644
index 0000000000..60f053b381
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.FunSuite
+
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
+  squaredDistance => breezeSquaredDistance}
+
+import org.apache.spark.mllib.util.MLUtils._
+
+class MLUtilsSuite extends FunSuite {
+
+  test("epsilon computation") {
+    assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
+    assert(1.0 + EPSILON / 2.0 === 1.0, s"EPSILON is too big: $EPSILON.")
+  }
+
+  test("fast squared distance") {
+    val a = (30 to 0 by -1).map(math.pow(2.0, _)).toArray
+    val n = a.length
+    val v1 = new BDV[Double](a)
+    val norm1 = breezeNorm(v1, 2.0)
+    val precision = 1e-6
+    for (m <- 0 until n) {
+      val indices = (0 to m).toArray
+      val values = indices.map(i => a(i))
+      val v2 = new BSV[Double](indices, values, n)
+      val norm2 = breezeNorm(v2, 2.0)
+      val squaredDist = breezeSquaredDistance(v1, v2)
+      val fastSquaredDist1 = fastSquaredDistance(v1, norm1, v2, norm2, precision)
+      assert((fastSquaredDist1 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
+      val fastSquaredDist2 = fastSquaredDistance(v1, norm1, v2.toDenseVector, norm2, precision)
+      assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
+    }
+  }
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b08fb26adf..1969486f79 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -366,7 +366,8 @@ object SparkBuild extends Build {
   def mllibSettings = sharedSettings ++ Seq(
     name := "spark-mllib",
     libraryDependencies ++= Seq(
-      "org.jblas" % "jblas" % "1.2.3"
+      "org.jblas" % "jblas" % "1.2.3",
+      "org.scalanlp" %% "breeze" % "0.7"
     )
   )
 
-- 
GitLab