Commit e9c9ae22 authored by Antonio Murgia, committed by Sean Owen

[SPARK-11994][MLLIB] Word2VecModel load and save cause SparkException when model is bigger than spark.kryoserializer.buffer.max

Author: Antonio Murgia <antonio.murgia2@studio.unibo.it>

Closes #9989 from tmnd1991/SPARK-11932.
parent ee94b70c
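
For context, the fix writes the model vectors across multiple partitions sized against spark.kryoserializer.buffer.max, as shown in the diff below. A minimal standalone sketch of that sizing arithmetic follows; the object and method names here are illustrative only and are not part of the Spark source:

// Illustrative sketch of the partitioning arithmetic used by the patch,
// assuming 4-byte floats and a 32MB (2^25 byte) target partition size.
object PartitionSizingSketch {
  def numPartitions(numWords: Long, vectorSize: Long): Int = {
    val partitionSize = 1L << 25                 // 32MB per partition
    val approxSize = 4L * numWords * vectorSize  // array bytes only; string keys ignored
    ((approxSize / partitionSize) + 1).toInt
  }

  def main(args: Array[String]): Unit = {
    // The regression test below uses roughly 9000 words x 1000 dimensions:
    // 9000 * 1000 * 4 = 36,000,000 bytes > 2^25, so the vectors are written
    // across 2 partitions instead of one oversized Kryo record.
    println(numPartitions(9000, 1000)) // prints 2
  }
}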
@@ -604,13 +604,21 @@ object Word2VecModel extends Loader[Word2VecModel] {
       val vectorSize = model.values.head.size
       val numWords = model.size
-      val metadata = compact(render
-        (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+      val metadata = compact(render(
+        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
           ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      // We want to partition the model in partitions of size 32MB
+      val partitionSize = (1L << 25)
+      // We calculate the approximate size of the model
+      // We only calculate the array size, not considering
+      // the string size, the formula is:
+      // floatSize * numWords * vectorSize
+      val approxSize = 4L * numWords * vectorSize
+      val nPartitions = ((approxSize / partitionSize) + 1).toInt
       val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
-      sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path))
+      sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path))
     }
   }
@@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
+
+  test("big model load / save") {
+    // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25
+    val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*)
+    val model = new Word2VecModel(word2VecMap)
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model.save(sc, path)
+      val sameModel = Word2VecModel.load(sc, path)
+      assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
 }
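
Note: before this change, a natural workaround was to raise the Kryo buffer ceiling instead of partitioning the save. A minimal sketch of that configuration, assuming the standard spark.kryoserializer.buffer.max property and a purely illustrative value of 512m:

import org.apache.spark.SparkConf

// Illustrative pre-patch workaround: allow larger single records through Kryo.
// With this commit, Word2VecModel.save writes the vectors across ~32MB partitions,
// so the default buffer limit no longer has to be raised for large models.
val conf = new SparkConf().set("spark.kryoserializer.buffer.max", "512m")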