Commit 7b0ed797 authored by Liang-Chi Hsieh, committed by Xiangrui Meng

[SPARK-5419][Mllib] Fix the logic in Vectors.sqdist

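The current implementation of Vectors.sqdist is inefficient because it allocates temporary arrays. There is also a bug in the guard `v1.indices.length / v1.size < 0.5`: both operands are Ints, so the division truncates to 0 for any vector with fewer stored entries than its size, and the guard is almost always true. This PR fixes the bug and refactors sqdist so that it no longer allocates new arrays.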

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #4217 from viirya/fix_sqdist and squashes the following commits:

e8b0b3d [Liang-Chi Hsieh] For review comments.
314c424 [Liang-Chi Hsieh] Fix sqdist bug.
parent d6894b1c
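
The guard bug mentioned in the commit message can be reproduced in isolation. The sketch below is illustrative only: `GuardDemo`, `buggyGuard`, `intendedGuard`, `nnz` and `size` are hypothetical names standing in for `v1.indices.length` and `v1.size`, and the floating-point version is an assumption about what the guard was meant to compute. The commit itself removes the guard rather than correcting it.

```scala
object GuardDemo {
  // Mirrors the original guard: Int / Int truncates before the comparison.
  def buggyGuard(nnz: Int, size: Int): Boolean = nnz / size < 0.5

  // Assumed intent: compare the actual fill ratio against 0.5.
  def intendedGuard(nnz: Int, size: Int): Boolean = nnz.toDouble / size < 0.5

  def main(args: Array[String]): Unit = {
    // A vector with 9 stored entries out of 10 is nearly dense,
    // yet the truncated ratio 9 / 10 == 0 passes the guard.
    println(buggyGuard(9, 10))
    println(intendedGuard(9, 10))
  }
}
```

With integer division the ratio is 0 unless every entry is stored, so the "approximately dense" fallback was reached only for fully populated sparse vectors.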
@@ -371,18 +371,23 @@ object Vectors {
           squaredDistance += score * score
         }

-      case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 =>
+      case (v1: SparseVector, v2: DenseVector) =>
         squaredDistance = sqdist(v1, v2)

-      case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 =>
+      case (v1: DenseVector, v2: SparseVector) =>
         squaredDistance = sqdist(v2, v1)

-      // When a SparseVector is approximately dense, we treat it as a DenseVector
-      case (v1, v2) =>
-        squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) =>
-          val score = elems._1 - elems._2
-          distance + score * score
+      case (DenseVector(vv1), DenseVector(vv2)) =>
+        var kv = 0
+        val sz = vv1.size
+        while (kv < sz) {
+          val score = vv1(kv) - vv2(kv)
+          squaredDistance += score * score
+          kv += 1
         }
+      case _ =>
+        throw new IllegalArgumentException("Do not support vector type " + v1.getClass +
+          " and " + v2.getClass)
     }
     squaredDistance
   }
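
For readers who want to try the allocation-free approach outside Spark, here is a minimal sketch on plain arrays. `denseSqdist` follows the same loop shape as the new `DenseVector` case in the diff. `sparseDenseSqdist` is a hypothetical stand-in for the sparse/dense helper that the diff calls but does not show, so its body is an assumption rather than the actual MLlib code. `SqdistSketch` and all parameter names are illustrative.

```scala
object SqdistSketch {
  // Dense/dense squared distance with an index loop, no temporary arrays.
  def denseSqdist(a: Array[Double], b: Array[Double]): Double = {
    require(a.length == b.length, "vector sizes must match")
    var sum = 0.0
    var k = 0
    while (k < a.length) {
      val d = a(k) - b(k)
      sum += d * d
      k += 1
    }
    sum
  }

  // Sparse/dense squared distance (hypothetical helper).
  // Assumes `indices` is strictly increasing and within the bounds of `dense`.
  def sparseDenseSqdist(
      indices: Array[Int],
      values: Array[Double],
      dense: Array[Double]): Double = {
    var sum = 0.0
    var kSparse = 0
    var kDense = 0
    while (kDense < dense.length) {
      // Positions without a stored entry contribute an implicit 0.0.
      val stored = kSparse < indices.length && indices(kSparse) == kDense
      val sparseValue = if (stored) values(kSparse) else 0.0
      val d = sparseValue - dense(kDense)
      sum += d * d
      if (stored) kSparse += 1
      kDense += 1
    }
    sum
  }
}
```

Both functions walk the inputs once and keep only a few local variables, whereas the removed `toArray.zip(...).foldLeft` version materialized two dense arrays plus an array of pairs on every call.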