diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 5b079fce3a83d06d39f5997402666f82fab45925..7e6c3679704c5d97d55bd4d7c1167a56746f13ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -540,14 +540,16 @@ class Word2VecModel private[spark] (
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
     val beta: Float = 0
-
+    // Normalize input vector before blas.sgemv to avoid Inf value
+    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
+    if (vecNorm != 0.0f) {
+      blas.sscal(vectorSize, 1 / vecNorm, fVector, 0, 1)
+    }
     blas.sgemv(
       "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
 
-    // Need not divide with the norm of the given vector since it is constant.
     val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
-    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
     while (ind < numWords) {
       val norm = wordVecNorms(ind)
       if (norm == 0.0) {
@@ -557,17 +559,13 @@ class Word2VecModel private[spark] (
       }
       ind += 1
     }
-    var topResults = wordList.zip(cosVec)
+
+    wordList.zip(cosVec)
       .toSeq
       .sortBy(-_._2)
       .take(num + 1)
       .tail
-    if (vecNorm != 0.0f) {
-      topResults = topResults.map { case (word, cosVal) =>
-        (word, cosVal / vecNorm)
-      }
-    }
-    topResults.toArray
+      .toArray
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 4fcf417d5f82efcf5452fd42e1bdbac55c378ca0..6d699440f2f2e06c595160b7382e723c6e3ae7c6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -108,5 +108,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("test similarity for word vectors with large values is not Infinity or NaN") {
+    val vecA = Array(-4.331467827487745E21, -5.26707742075006E21,
+      5.63551690626524E21, 2.833692188614257E21, -1.9688159903619345E21, -4.933950659913092E21,
+      -2.7401535502536787E21, -1.418671793782632E20).map(_.toFloat)
+    val vecB = Array(-3.9850175451103232E16, -3.4829783883841536E16,
+      9.421469251534848E15, 4.4069684466679808E16, 7.20936298872832E15, -4.2883302830374912E16,
+      -3.605579947835392E16, -2.8151294422155264E16).map(_.toFloat)
+    val vecC = Array(-1.9227381025734656E16, -3.907009342603264E16,
+      2.110207626838016E15, -4.8770066610651136E16, -1.9734964555743232E16, -3.2206001247617024E16,
+      2.7725358220443648E16, 3.1618718156980224E16).map(_.toFloat)
+    val wordMapIn = Map(
+      ("A", vecA),
+      ("B", vecB),
+      ("C", vecC)
+    )
+
+    val model = new Word2VecModel(wordMapIn)
+    model.findSynonyms("A", 5).foreach { pair =>
+      assert(!(pair._2.isInfinite || pair._2.isNaN))
+    }
+  }
 
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 610d167f3ad08b8198cd56ba555305b797836067..1b059a719913dd26167a103d0ca8978fd3c9fd06 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2186,13 +2186,14 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     |   c|[-0.3794820010662...|
     +----+--------------------+
     ...
-    >>> model.findSynonyms("a", 2).show()
-    +----+-------------------+
-    |word|         similarity|
-    +----+-------------------+
-    |   b| 0.2505344027513247|
-    |   c|-0.6980510075367647|
-    +----+-------------------+
+    >>> from pyspark.sql.functions import format_number as fmt
+    >>> model.findSynonyms("a", 2).select("word", fmt("similarity", 5).alias("similarity")).show()
+    +----+----------+
+    |word|similarity|
+    +----+----------+
+    |   b|   0.25053|
+    |   c|  -0.69805|
+    +----+----------+
     ...
     >>> model.transform(doc).head().model
     DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])
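
Note (not part of the patch): the fix works by computing cosine similarities from a unit-length query vector. Below is a minimal, standalone sketch of the underlying overflow, assuming entries of roughly the same magnitude as vecA in the new test; the object name Word2VecOverflowSketch is invented for illustration.

    // Sketch only: shows why an un-normalized query vector of large Float values
    // drives the similarities to Infinity, and why normalizing first avoids it.
    object Word2VecOverflowSketch {
      def main(args: Array[String]): Unit = {
        val v = Array.fill(8)(1.0e21f)

        // Un-normalized dot product: each partial product is 1e42, far beyond
        // Float.MaxValue (~3.4e38), so the sum overflows to Infinity.
        println(v.map(x => x * x).sum) // Infinity

        // Normalizing first keeps every product within Float range. The norm is
        // accumulated in Double here; blas.snrm2 in the patch is assumed to avoid
        // overflow internally via scaling, as reference BLAS implementations do.
        val norm = math.sqrt(v.map(x => x.toDouble * x).sum).toFloat
        val unit = v.map(_ / norm)
        println(unit.map(x => x * x).sum) // ~1.0, a finite cosine similarity
      }
    }

Because fVector is scaled to unit length before blas.sgemv, the old post-processing step that divided each top result by vecNorm is no longer needed, which is why that block is removed above.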