diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 7960f3cab576f5daaea6e8367e6e64ab68c32007..d25a7cd5b439d591312b154462142f03c44e3296 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -71,7 +71,8 @@ class Word2Vec extends Serializable with Logging { private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() - + private var minCount = 5 + /** * Sets vector size (default: 100). */ @@ -114,6 +115,15 @@ class Word2Vec extends Serializable with Logging { this } + /** + * Sets minCount, the minimum number of times a token must appear to be included in the word2vec + * model's vocabulary (default: 5). + */ + def setMinCount(minCount: Int): this.type = { + this.minCount = minCount + this + } + private val EXP_TABLE_SIZE = 1000 private val MAX_EXP = 6 private val MAX_CODE_LENGTH = 40 @@ -122,9 +132,6 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private val window = 5 - /** minimum frequency to consider a vocabulary word */ - private val minCount = 5 - private var trainWordsCount = 0 private var vocabSize = 0 private var vocab: Array[VocabWord] = null diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 47d1a76fa361df2fedbef46fe7bc52dfa7192e13..01f3f90577142791cffe99e3874651d4e049acf2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -268,7 +268,7 @@ object Vectors { * @param p norm. * @return norm in L^p^ space. */ - private[spark] def norm(vector: Vector, p: Double): Double = { + def norm(vector: Vector, p: Double): Double = { require(p >= 1.0) val values = vector match { case dv: DenseVector => dv.values