Skip to content
Snippets Groups Projects
Commit 89cda69e authored by Feynman Liang, committed by Joseph K. Bradley
Browse files

[SPARK-9454] Change LDASuite tests to use vector comparisons

@jkbradley Replaces the current hacky string comparisons in the LDASuite tests with vector comparisons.

Author: Feynman Liang <fliang@databricks.com>

Closes #7775 from feynmanliang/SPARK-9454-ldasuite-vector-compare and squashes the following commits:

bd91a82 [Feynman Liang] Remove println
905c76e [Feynman Liang] Fix string compare in distributed EM
2f24c13 [Feynman Liang] Improve LDASuite tests
parent 1abf7dc1
No related branches found
No related tags found
No related merge requests found
......@@ -83,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.topicsMatrix === localModel.topicsMatrix)
// Check: topic summaries
// The odd decimal formatting and sorting is a hack to do a robust comparison.
val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
// cut values to 3 digits after the decimal place
terms.zip(termWeights).map { case (term, weight) =>
("%.3f".format(weight).toDouble, term.toInt)
}
}.sortBy(_.mkString(""))
val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
// cut values to 3 digits after the decimal place
terms.zip(termWeights).map { case (term, weight) =>
("%.3f".format(weight).toDouble, term.toInt)
}
}.sortBy(_.mkString(""))
roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
assert(t1 === t2)
val topicSummary = model.describeTopics().map { case (terms, termWeights) =>
Vectors.sparse(tinyVocabSize, terms, termWeights)
}.sortBy(_.toString)
val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
Vectors.sparse(tinyVocabSize, terms, termWeights)
}.sortBy(_.toString)
topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) =>
assert(topics ~== topicsLocal absTol 0.01)
}
// Check: per-doc topic distributions
......@@ -197,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
// Verify the result. Note that this generates the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t)
val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t)
val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950)
val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050)
assert(topic1 ~== expectedTopic1 absTol 0.01)
assert(topic2 ~== expectedTopic2 absTol 0.01)
}
test("OnlineLDAOptimizer with toy data") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment