Skip to content
Snippets Groups Projects
Commit f0060b75 authored by Liquan Pei's avatar Liquan Pei Committed by Xiangrui Meng
Browse files

[MLlib] Correctly set vectorSize and alpha

mengxr
Correctly set vectorSize and alpha in Word2Vec training.

Author: Liquan Pei <liquanpei@gmail.com>

Closes #1900 from Ishiihara/Word2Vec-bugfix and squashes the following commits:

85f64f2 [Liquan Pei] correctly set vectorSize and alpha
parent 9038d94e
No related branches found
No related tags found
No related merge requests found
...@@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging { ...@@ -119,7 +119,6 @@ class Word2Vec extends Serializable with Logging {
private val MAX_EXP = 6 private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40 private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000 private val MAX_SENTENCE_LENGTH = 1000
private val layer1Size = vectorSize
/** context words from [-window, window] */ /** context words from [-window, window] */
private val window = 5 private val window = 5
...@@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging { ...@@ -131,7 +130,6 @@ class Word2Vec extends Serializable with Logging {
private var vocabSize = 0 private var vocabSize = 0
private var vocab: Array[VocabWord] = null private var vocab: Array[VocabWord] = null
private var vocabHash = mutable.HashMap.empty[String, Int] private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha
private def learnVocab(words: RDD[String]): Unit = { private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1)) vocab = words.map(w => (w, 1))
...@@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging { ...@@ -287,9 +285,10 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache() val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed) val initRandom = new XORShiftRandom(seed)
var syn0Global = var syn0Global =
Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size) Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * layer1Size) var syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) { for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
...@@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging { ...@@ -317,24 +316,24 @@ class Word2Vec extends Serializable with Logging {
val c = pos - window + a val c = pos - window + a
if (c >= 0 && c < sentence.size) { if (c >= 0 && c < sentence.size) {
val lastWord = sentence(c) val lastWord = sentence(c)
val l1 = lastWord * layer1Size val l1 = lastWord * vectorSize
val neu1e = new Array[Float](layer1Size) val neu1e = new Array[Float](vectorSize)
// Hierarchical softmax // Hierarchical softmax
var d = 0 var d = 0
while (d < bcVocab.value(word).codeLen) { while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size val l2 = bcVocab.value(word).point(d) * vectorSize
// Propagate hidden -> output // Propagate hidden -> output
var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1) var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
if (f > -MAX_EXP && f < MAX_EXP) { if (f > -MAX_EXP && f < MAX_EXP) {
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
f = expTable.value(ind) f = expTable.value(ind)
val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1) blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1) blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
} }
d += 1 d += 1
} }
blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1) blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
} }
} }
a += 1 a += 1
...@@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging { ...@@ -365,8 +364,8 @@ class Word2Vec extends Serializable with Logging {
var i = 0 var i = 0
while (i < vocabSize) { while (i < vocabSize) {
val word = bcVocab.value(i).word val word = bcVocab.value(i).word
val vector = new Array[Float](layer1Size) val vector = new Array[Float](vectorSize)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size) Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
word2VecMap += word -> vector word2VecMap += word -> vector
i += 1 i += 1
} }
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment