Skip to content
Snippets Groups Projects
Commit 0e00f12d authored by MechCoder's avatar MechCoder Committed by Xiangrui Meng
Browse files

[SPARK-5692] [MLlib] Word2Vec save/load

Word2Vec model now supports saving and loading.

a] The metadata, stored in JSON format, consists of "version", "classname", "vectorSize" and "numWords".
b] The data, stored in Parquet file format, consists of an array of rows, each row having two columns: the word (a String) and its vector (an Array of Floats).

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5291 from MechCoder/spark-5692 and squashes the following commits:

1142f3a [MechCoder] Add numWords to metaData
bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load
parent 2036bc59
No related branches found
No related tags found
No related merge requests found
......@@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.Logging
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.sql.{SQLContext, Row}
/**
* Entry in vocabulary
......@@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging {
*/
@Experimental
class Word2VecModel private[mllib] (
private val model: Map[String, Array[Float]]) extends Serializable {
private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
......@@ -432,7 +439,13 @@ class Word2VecModel private[mllib] (
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
}
// Version of the on-disk format this model writes (required by the Saveable trait).
override protected def formatVersion = "1.0"

/**
 * Saves this model under `path` in the V1.0 layout: JSON metadata
 * plus a Parquet data file (see Word2VecModel.SaveLoadV1_0.save).
 *
 * @param sc   Spark context used to write the output files
 * @param path directory to save the model under
 */
def save(sc: SparkContext, path: String): Unit = {
Word2VecModel.SaveLoadV1_0.save(sc, path, model)
}
/**
* Transforms a word to its vector representation
* @param word a word
......@@ -475,7 +488,7 @@ class Word2VecModel private[mllib] (
.tail
.toArray
}
/**
* Returns a map of words to their vector representations.
*/
......@@ -483,3 +496,71 @@ class Word2VecModel private[mllib] (
model
}
}
@Experimental
object Word2VecModel extends Loader[Word2VecModel] {

  private object SaveLoadV1_0 {

    val formatVersionV1_0 = "1.0"

    val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"

    /** Row layout persisted in the Parquet data file: one (word, vector) pair per row. */
    case class Data(word: String, vector: Array[Float])

    /**
     * Reads the word -> vector map back from the Parquet data file under `path`
     * and rebuilds the model.
     */
    def load(sc: SparkContext, path: String): Word2VecModel = {
      val dataPath = Loader.dataPath(path)
      val sqlContext = new SQLContext(sc)
      val dataFrame = sqlContext.parquetFile(dataPath)
      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      // Do this before collect() so a malformed file fails fast, without pulling
      // all rows to the driver first.
      Loader.checkSchema[Data](dataFrame.schema)
      val dataArray = dataFrame.select("word", "vector").collect()
      val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
      new Word2VecModel(word2VecMap)
    }

    /**
     * Writes JSON metadata (class, version, vectorSize, numWords) and the
     * word -> vector map (as Parquet) under `path`.
     */
    def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]): Unit = {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      val vectorSize = model.values.head.size
      val numWords = model.size
      val metadata = compact(render(
        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
          ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
      sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
    }
  }

  /**
   * Loads a Word2VecModel previously written by `save`, validating the stored
   * class name, format version, vector size and word count against the metadata.
   *
   * @throws Exception if the (className, version) pair is not recognized
   */
  override def load(sc: SparkContext, path: String): Word2VecModel = {
    val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
    val expectedNumWords = (metadata \ "numWords").extract[Int]
    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
    (loadedClassName, loadedVersion) match {
      // BUG FIX: the backticks make this a stable-identifier pattern that compares
      // against the val above. A bare lowercase `classNameV1_0` here would be a fresh
      // pattern variable that matches ANY class name, silently skipping the check
      // and making the `case _` error branch unreachable for version "1.0".
      case (`classNameV1_0`, SaveLoadV1_0.formatVersionV1_0) =>
        val model = SaveLoadV1_0.load(sc, path)
        val vectorSize = model.getVectors.values.head.size
        val numWords = model.getVectors.size
        require(expectedVectorSize == vectorSize,
          s"Word2VecModel requires each word to be mapped to a vector of size " +
          s"$expectedVectorSize, got vector of size $vectorSize")
        require(expectedNumWords == numWords,
          s"Word2VecModel requires $expectedNumWords words, but got $numWords")
        model
      case _ => throw new Exception(
        s"Word2VecModel.load did not recognize model with (className, format version):" +
        s"($loadedClassName, $loadedVersion). Supported:\n" +
        s"  ($classNameV1_0, ${SaveLoadV1_0.formatVersionV1_0})")
    }
  }
}
......@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
// TODO: add more tests
......@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
test("model load / save") {
  // Small fixed vocabulary with 4-dimensional vectors.
  val vectors = Map(
    "china" -> Array(0.50f, 0.50f, 0.50f, 0.50f),
    "japan" -> Array(0.40f, 0.50f, 0.50f, 0.50f),
    "taiwan" -> Array(0.60f, 0.50f, 0.50f, 0.50f),
    "korea" -> Array(0.45f, 0.60f, 0.60f, 0.60f)
  )
  val original = new Word2VecModel(vectors)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString
  try {
    original.save(sc, path)
    val restored = Word2VecModel.load(sc, path)
    // Compare element-wise as Seq: Array equality in Scala is reference equality.
    assert(restored.getVectors.mapValues(_.toSeq) === original.getVectors.mapValues(_.toSeq))
  } finally {
    Utils.deleteRecursively(tempDir)
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment