Skip to content
Snippets Groups Projects
Commit 13675c74 authored by MechCoder's avatar MechCoder Committed by Joseph K. Bradley
Browse files

[SPARK-8874] [ML] Add missing methods in Word2Vec

Add missing methods

1. getVectors
2. findSynonyms

to W2Vec scala and python API

mengxr

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7263 from MechCoder/missing_methods_w2vec and squashes the following commits:

149d5ca [MechCoder] minor doc
69d91b7 [MechCoder] [SPARK-8874] [ML] Add missing methods in Word2Vec
parent a2409d1c
No related branches found
No related tags found
No related merge requests found
......@@ -18,15 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
......@@ -146,6 +148,40 @@ class Word2VecModel private[ml] (
wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
* and the vector the DenseVector that it is mapped to.
*/
val getVectors: DataFrame = {
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
val wordVec = wordVectors.getVectors.mapValues(vec => Vectors.dense(vec.map(_.toDouble)))
sc.parallelize(wordVec.toSeq).toDF("word", "vector")
}
/**
* Find "num" number of words closest in similarity to the given word.
* Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word.
*/
def findSynonyms(word: String, num: Int): DataFrame = {
findSynonyms(wordVectors.transform(word), num)
}
/**
* Find "num" number of words closest to similarity to the given vector representation
* of the word. Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word vector.
*/
def findSynonyms(word: Vector, num: Int): DataFrame = {
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
sc.parallelize(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
}
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
......
......@@ -67,5 +67,67 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
}
}
test("getVectors") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val codes = Map(
"a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
"b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
"c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351)
)
val expectedVectors = codes.toSeq.sortBy(_._1).map { case (w, v) => Vectors.dense(v) }
val docDF = doc.zip(doc).toDF("text", "alsotext")
val model = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)
val realVectors = model.getVectors.sort("word").select("vector").map {
case Row(v: Vector) => v
}.collect()
realVectors.zip(expectedVectors).foreach {
case (real, expected) =>
assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.")
}
}
test("findSynonyms") {
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val docDF = doc.zip(doc).toDF("text", "alsotext")
val model = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)
val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644)
val (synonyms, similarity) = model.findSynonyms("a", 2).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip
assert(synonyms.toArray === Array("b", "c"))
expectedSimilarity.zip(similarity).map {
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment