Skip to content
Snippets Groups Projects
Commit a5257048 authored by Liu Xiang's avatar Liu Xiang Committed by Xiangrui Meng
Browse files

[SPARK-12765][ML][COUNTVECTORIZER] fix CountVectorizer.transform's lost transformSchema

https://issues.apache.org/jira/browse/SPARK-12765

Author: Liu Xiang <lxmtlab@gmail.com>

Closes #10720 from sloth2012/sloth.
parent b3546738
No related branches found
No related tags found
No related merge requests found
...@@ -210,6 +210,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin ...@@ -210,6 +210,7 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
override def transform(dataset: DataFrame): DataFrame = { override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) { if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap val dict = vocabulary.zipWithIndex.toMap
broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict)) broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment