Skip to content
Snippets Groups Projects
Commit 22b9a874 authored by Holden Karau's avatar Holden Karau Committed by Sean Owen
Browse files

[SPARK-10299][ML] word2vec should allow users to specify the window size

Currently word2vec has the window hard coded at 5, some users may want different sizes (for example if using on n-gram input or similar). User request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size .

Author: Holden Karau <holden@us.ibm.com>
Author: Holden Karau <holden@pigscanfly.ca>

Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size.
parent 6e1c55ea
No related branches found
No related tags found
No related merge requests found
...@@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params ...@@ -49,6 +49,17 @@ private[feature] trait Word2VecBase extends Params
/** @group getParam */ /** @group getParam */
def getVectorSize: Int = $(vectorSize) def getVectorSize: Int = $(vectorSize)
/**
* The window size (context words from [-window, window]) default 5.
* @group expertParam
*/
final val windowSize = new IntParam(
this, "windowSize", "the window size (context words from [-window, window])")
setDefault(windowSize -> 5)
/** @group expertGetParam */
def getWindowSize: Int = $(windowSize)
/** /**
* Number of partitions for sentences of words. * Number of partitions for sentences of words.
* Default: 1 * Default: 1
...@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] ...@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */ /** @group setParam */
def setVectorSize(value: Int): this.type = set(vectorSize, value) def setVectorSize(value: Int): this.type = set(vectorSize, value)
/** @group expertSetParam */
def setWindowSize(value: Int): this.type = set(windowSize, value)
/** @group setParam */ /** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value) def setStepSize(value: Double): this.type = set(stepSize, value)
...@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] ...@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
.setNumPartitions($(numPartitions)) .setNumPartitions($(numPartitions))
.setSeed($(seed)) .setSeed($(seed))
.setVectorSize($(vectorSize)) .setVectorSize($(vectorSize))
.setWindowSize($(windowSize))
.fit(input) .fit(input)
copyValues(new Word2VecModel(uid, wordVectors).setParent(this)) copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
} }
......
...@@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging { ...@@ -125,6 +125,15 @@ class Word2Vec extends Serializable with Logging {
this this
} }
/**
* Sets the window of words (default: 5)
*/
@Since("1.6.0")
def setWindowSize(window: Int): this.type = {
this.window = window
this
}
/** /**
* Sets minCount, the minimum number of times a token must appear to be included in the word2vec * Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5). * model's vocabulary (default: 5).
...@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging { ...@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging {
private val MAX_SENTENCE_LENGTH = 1000 private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */ /** context words from [-window, window] */
private val window = 5 private var window = 5
private var trainWordsCount = 0 private var trainWordsCount = 0
private var vocabSize = 0 private var vocabSize = 0
......
...@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ...@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
} }
test("Word2Vec") { test("Word2Vec") {
val sqlContext = new SQLContext(sc)
val sqlContext = this.sqlContext
import sqlContext.implicits._ import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
...@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ...@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") { test("getVectors") {
val sqlContext = new SQLContext(sc) val sqlContext = this.sqlContext
import sqlContext.implicits._ import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
...@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ...@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") { test("findSynonyms") {
val sqlContext = new SQLContext(sc) val sqlContext = this.sqlContext
import sqlContext.implicits._ import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10 val sentence = "a b " * 100 + "a c " * 10
...@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ...@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
expectedSimilarity.zip(similarity).map { expectedSimilarity.zip(similarity).map {
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
} }
}
test("window size") {
val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
val docDF = doc.zip(doc).toDF("text", "alsotext")
val model = new Word2Vec()
.setVectorSize(3)
.setWindowSize(2)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.fit(docDF)
val (synonyms, similarity) = model.findSynonyms("a", 6).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip
// Increase the window size
val biggerModel = new Word2Vec()
.setVectorSize(3)
.setInputCol("text")
.setOutputCol("result")
.setSeed(42L)
.setWindowSize(10)
.fit(docDF)
val (synonymsLarger, similarityLarger) = model.findSynonyms("a", 6).map {
case Row(w: String, sim: Double) => (w, sim)
}.collect().unzip
// The similarity score should be very different with the larger window
assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
} }
test("Word2Vec read/write") { test("Word2Vec read/write") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment