diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index dee898827f30fbe1bebd3fc19949ccfe6b1f623f..3241ebeb22c4256f99ce29bac4428a1f74e26ab7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -76,6 +76,18 @@ class Word2Vec extends Serializable with Logging {
   private var numIterations = 1
   private var seed = Utils.random.nextLong()
   private var minCount = 5
+  private var maxSentenceLength = 1000
+
+  /**
+   * Sets the maximum length (in words) of each sentence in the input data.
+   * Any sentence longer than this threshold will be divided into chunks of
+   * up to `maxSentenceLength` size (default: 1000).
+   */
+  @Since("2.0.0")
+  def setMaxSentenceLength(maxSentenceLength: Int): this.type = {
+    this.maxSentenceLength = maxSentenceLength
+    this
+  }
 
   /**
    * Sets vector size (default: 100).
@@ -146,7 +158,6 @@ class Word2Vec extends Serializable with Logging {
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
-  private val MAX_SENTENCE_LENGTH = 1000
 
   /** context words from [-window, window] */
   private var window = 5
@@ -156,7 +167,9 @@ class Word2Vec extends Serializable with Logging {
   @transient private var vocab: Array[VocabWord] = null
   @transient private var vocabHash = mutable.HashMap.empty[String, Int]
 
-  private def learnVocab(words: RDD[String]): Unit = {
+  private def learnVocab[S <: Iterable[String]](dataset: RDD[S]): Unit = {
+    val words = dataset.flatMap(x => x)
+
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .filter(_._2 >= minCount)
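The new setter chains with the rest of Word2Vec's builder-style configuration. A minimal usage sketch follows; the `SparkContext` value `sc` and the toy corpus are assumptions for illustration, not part of the patch:

```scala
import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// Assumed toy corpus: an RDD of tokenized sentences (any Iterable[String]).
val corpus: RDD[Seq[String]] = sc.parallelize(Seq(
  "spark mllib supports word2vec".split(" ").toSeq,
  "overly long sentences are now split into bounded chunks".split(" ").toSeq
))

val model = new Word2Vec()
  .setVectorSize(10)
  .setMinCount(1)
  .setMaxSentenceLength(500) // added by this patch; default remains 1000
  .fit(corpus)
```

Previously the 1000-word cap was the hard-coded `MAX_SENTENCE_LENGTH` constant removed above; making it a setter lets callers tune the chunk size to their corpus.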
@@ -272,15 +285,14 @@ class Word2Vec extends Serializable with Logging {
 
   /**
    * Computes the vector representation of each word in vocabulary.
-   * @param dataset an RDD of words
+   * @param dataset an RDD of sentences,
+   *                each sentence expressed as an iterable collection of words
    * @return a Word2VecModel
    */
   @Since("1.1.0")
   def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
-    val words = dataset.flatMap(x => x)
-
-    learnVocab(words)
+    learnVocab(dataset)
 
     createBinaryTree()
 
@@ -289,25 +301,15 @@ class Word2Vec extends Serializable with Logging {
     val expTable = sc.broadcast(createExpTable())
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
-
-    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
-      new Iterator[Array[Int]] {
-        def hasNext: Boolean = iter.hasNext
-
-        def next(): Array[Int] = {
-          val sentence = ArrayBuilder.make[Int]
-          var sentenceLength = 0
-          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-            val word = bcVocabHash.value.get(iter.next())
-            word match {
-              case Some(w) =>
-                sentence += w
-                sentenceLength += 1
-              case None =>
-            }
-          }
-          sentence.result()
-        }
+    // Each partition is a collection of sentences,
+    // each of which is translated into arrays of word indexes.
+    val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
+      // Each sentence will map to 0 or more Array[Int]
+      sentenceIter.flatMap { sentence =>
+        // Sentence of words, some of which map to a word index
+        val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
+        // break wordIndexes into chunks of at most maxSentenceLength words
+        wordIndexes.grouped(maxSentenceLength).map(_.toArray)
       }
     }
 
@@ -477,15 +479,6 @@ class Word2VecModel private[spark] (
     this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
   }
 
-  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
-    require(v1.length == v2.length, "Vectors should have the same length")
-    val n = v1.length
-    val norm1 = blas.snrm2(n, v1, 1)
-    val norm2 = blas.snrm2(n, v2, 1)
-    if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
-  }
-
   override protected def formatVersion = "1.0"
 
   @Since("1.4.0")
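The rewritten `sentences` pipeline above replaces the hand-rolled stateful `Iterator` with a `flatMap` over whole sentences, so chunks never span a sentence boundary (the old code consumed a flat stream of words and cut it every `MAX_SENTENCE_LENGTH` words regardless of where sentences ended). A dependency-free sketch of the chunking step; the three-word vocabulary and the chunk size of 3 are hypothetical:

```scala
// Hypothetical vocabulary hash and an over-long tokenized sentence.
val vocabHash = Map("a" -> 0, "b" -> 1, "c" -> 2)
val sentence = Seq("a", "b", "oov", "c", "a", "b", "c")
val maxSentenceLength = 3

// Out-of-vocabulary words are dropped by the flatMap, then the surviving
// indexes are split into chunks of at most maxSentenceLength words.
val chunks = sentence
  .flatMap(vocabHash.get)       // Seq(0, 1, 2, 0, 1, 2)
  .grouped(maxSentenceLength)   // groups of at most 3 indexes
  .map(_.toArray)
  .toList                       // List(Array(0, 1, 2), Array(0, 1, 2))
```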
@@ -542,6 +535,7 @@ class Word2VecModel private[spark] (
     // Need not divide with the norm of the given vector since it is constant.
     val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
+    val vecNorm = blas.snrm2(vectorSize, fVector, 1)
     while (ind < numWords) {
       val norm = wordVecNorms(ind)
       if (norm == 0.0) {
@@ -551,12 +545,17 @@
       }
       ind += 1
     }
-    wordList.zip(cosVec)
+    var topResults = wordList.zip(cosVec)
       .toSeq
-      .sortBy(- _._2)
+      .sortBy(-_._2)
       .take(num + 1)
       .tail
-      .toArray
+    if (vecNorm != 0.0f) {
+      topResults = topResults.map { case (word, cosVal) =>
+        (word, cosVal / vecNorm)
+      }
+    }
+    topResults.toArray
   }
 
   /**
@@ -568,6 +567,7 @@ class Word2VecModel private[spark] (
       (word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))
     }
   }
+
 }
 
 @Since("1.4.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a73b565125668775c0585d144fed12dd48e1b9e7..f094c550e545a7a1e39db4e4c736e446b66d7a44 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -133,7 +133,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       .setSeed(42L)
       .fit(docDF)
 
-    val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823)
+    val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078)
     val (synonyms, similarity) = model.findSynonyms("a", 2).map {
       case Row(w: String, sim: Double) => (w, sim)
     }.collect().unzip
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index d017a231886cbb6635daf9edefa0a4399c85467d..464c9446f2f3932b02180ecbd11dcf46d1bba76f 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -1836,12 +1836,12 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     +----+--------------------+
     ...
     >>> model.findSynonyms("a", 2).show()
-    +----+--------------------+
-    |word|          similarity|
-    +----+--------------------+
-    |   b| 0.16782984556103436|
-    |   c|-0.46761559092107646|
-    +----+--------------------+
+    +----+-------------------+
+    |word|         similarity|
+    +----+-------------------+
+    |   b| 0.2505344027513247|
+    |   c|-0.6980510075367647|
+    +----+-------------------+
     ...
     >>> model.transform(doc).head().model
     DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])
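For context on the updated expectations above: `cosVec` holds the query-to-word scores, and the loop divides each entry by the corresponding word vector's norm. The division by the query vector's own norm used to be skipped entirely (it is constant, so it cannot change the ranking), which meant `findSynonyms` returned scaled scores rather than true cosine similarities; the new `vecNorm` division fixes the reported values, and the Scala test expectation and the pyspark doctest change accordingly. A dependency-free sketch of the formula the normalized scores now satisfy, as an illustrative stand-in for the deleted blas-based helper, not part of the patch:

```scala
// Cosine similarity of two equal-length float vectors; result is in [-1, 1].
def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
  require(v1.length == v2.length, "Vectors should have the same length")
  val dot = v1.zip(v2).map { case (a, b) => a.toDouble * b }.sum
  val norm1 = math.sqrt(v1.map(x => x.toDouble * x).sum)
  val norm2 = math.sqrt(v2.map(x => x.toDouble * x).sum)
  if (norm1 == 0 || norm2 == 0) 0.0 else dot / (norm1 * norm2)
}

// Example: cosineSimilarity(Array(1f, 0f), Array(1f, 1f)) ≈ 0.7071 (cos 45°)
```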