From 354f4582b637fa25d3892ec2b12869db50ed83c9 Mon Sep 17 00:00:00 2001
From: Yuhao Yang <hhbyyh@gmail.com>
Date: Tue, 18 Aug 2015 11:00:09 -0700
Subject: [PATCH] [SPARK-9028] [ML] Add CountVectorizer as an estimator to
 generate CountVectorizerModel

jira: https://issues.apache.org/jira/browse/SPARK-9028

Add an estimator for CountVectorizerModel. The estimator will extract a vocabulary from document collections according to the term frequency.

I changed the meaning of minCount as a filter across the corpus. This aligns with Word2Vec and the similar parameter in SKlearn.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7388 from hhbyyh/cvEstimator.
---
 .../spark/ml/feature/CountVectorizer.scala    | 235 ++++++++++++++++++
 .../ml/feature/CountVectorizerModel.scala     |  82 ------
 .../ml/feature/CountVectorizerSuite.scala     | 167 +++++++++++++
 .../ml/feature/CountVectorizorSuite.scala     |  73 ------
 4 files changed, 402 insertions(+), 155 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
 delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
 delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
new file mode 100644
index 0000000000..49028e4b85
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
+ */
+private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol {
+
+  /**
+   * Max size of the vocabulary.
+   * CountVectorizer will build a vocabulary that only considers the top
+   * vocabSize terms ordered by term frequency across the corpus.
+   *
+   * Default: 2^18^
+   * @group param
+   */
+  val vocabSize: IntParam =
+    new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0))
+
+  /** @group getParam */
+  def getVocabSize: Int = $(vocabSize)
+
+  /**
+   * Specifies the minimum number of different documents a term must appear in to be included
+   * in the vocabulary.
+   * If this is an integer >= 1, this specifies the number of documents the term must appear in;
+   * if this is a double in [0,1), then this specifies the fraction of documents.
+   *
+   * Default: 1
+   * @group param
+   */
+  val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" +
+    " different documents a term must appear in to be included in the vocabulary." +
+    " If this is an integer >= 1, this specifies the number of documents the term must" +
+    " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
+    ParamValidators.gtEq(0.0))
+
+  /** @group getParam */
+  def getMinDF: Double = $(minDF)
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+  }
+
+  /**
+   * Filter to ignore rare words in a document. For each document, terms with
+   * frequency/count less than the given threshold are ignored.
+   * If this is an integer >= 1, then this specifies a count (of times the term must appear
+   * in the document);
+   * if this is a double in [0,1), then this specifies a fraction (out of the document's token
+   * count).
+   *
+   * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not
+   * affect fitting.
+   *
+   * Default: 1
+   * @group param
+   */
+  val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" +
+    " a document. For each document, terms with frequency/count less than the given threshold are" +
+    " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" +
+    " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" +
+    " of the document's token count). Note that the parameter is only used in transform of" +
+    " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0))
+
+  setDefault(minTF -> 1)
+
+  /** @group getParam */
+  def getMinTF: Double = $(minTF)
+}
+
+/**
+ * :: Experimental ::
+ * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]].
+ */
+@Experimental
+class CountVectorizer(override val uid: String)
+  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+
+  def this() = this(Identifiable.randomUID("cntVec"))
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setVocabSize(value: Int): this.type = set(vocabSize, value)
+
+  /** @group setParam */
+  def setMinDF(value: Double): this.type = set(minDF, value)
+
+  /** @group setParam */
+  def setMinTF(value: Double): this.type = set(minTF, value)
+
+  setDefault(vocabSize -> (1 << 18), minDF -> 1)
+
+  override def fit(dataset: DataFrame): CountVectorizerModel = {
+    transformSchema(dataset.schema, logging = true)
+    val vocSize = $(vocabSize)
+    val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
+    val minDf = if ($(minDF) >= 1.0) {
+      $(minDF)
+    } else {
+      $(minDF) * input.cache().count()
+    }
+    val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
+      val wc = new OpenHashMap[String, Long]
+      tokens.foreach { w =>
+        wc.changeValue(w, 1L, _ + 1L)
+      }
+      wc.map { case (word, count) => (word, (count, 1)) }
+    }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
+      (wc1 + wc2, df1 + df2)
+    }.filter { case (word, (wc, df)) =>
+      df >= minDf
+    }.map { case (word, (count, dfCount)) =>
+      (word, count)
+    }.cache()
+    val fullVocabSize = wordCounts.count()
+    val vocab: Array[String] = {
+      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
+        // Use all terms
+        wordCounts.collect().sortBy(-_._2)
+      } else {
+        // Sort terms to select vocab
+        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
+      }
+      tmpSortedWC.map(_._1)
+    }
+
+    require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
+    copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Converts a text document to a sparse vector of token counts.
+ * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
+ */
+@Experimental
+class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
+  extends Model[CountVectorizerModel] with CountVectorizerParams {
+
+  def this(vocabulary: Array[String]) = {
+    this(Identifiable.randomUID("cntVecModel"), vocabulary)
+    set(vocabSize, vocabulary.length)
+  }
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /** @group setParam */
+  def setMinTF(value: Double): this.type = set(minTF, value)
+
+  /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
+  private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    if (broadcastDict.isEmpty) {
+      val dict = vocabulary.zipWithIndex.toMap
+      broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
+    }
+    val dictBr = broadcastDict.get
+    val minTf = $(minTF)
+    val vectorizer = udf { (document: Seq[String]) =>
+      val termCounts = new OpenHashMap[Int, Double]
+      var tokenCount = 0L
+      document.foreach { term =>
+        dictBr.value.get(term) match {
+          case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
+          case None => // ignore terms not in the vocabulary
+        }
+        tokenCount += 1
+      }
+      val effectiveMinTF = if (minTf >= 1.0) {
+        minTf
+      } else {
+        tokenCount * minTf
+      }
+      Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
+    }
+    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  override def copy(extra: ParamMap): CountVectorizerModel = {
+    val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
+    copyValues(copied, extra)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
deleted file mode 100644
index 6b77de89a0..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.ml.feature
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
-import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
-
-/**
- * :: Experimental ::
- * Converts a text document to a sparse vector of token counts.
- * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
- */
-@Experimental
-class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
-  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
-
-  def this(vocabulary: Array[String]) =
-    this(Identifiable.randomUID("cntVec"), vocabulary)
-
-  /**
-   * Corpus-specific filter to ignore scarce words in a document. For each document, terms with
-   * frequency (count) less than the given threshold are ignored.
-   * Default: 1
-   * @group param
-   */
-  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
-    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
-      "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
-
-  /** @group setParam */
-  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
-
-  /** @group getParam */
-  def getMinTermFreq: Int = $(minTermFreq)
-
-  setDefault(minTermFreq -> 1)
-
-  override protected def createTransformFunc: Seq[String] => Vector = {
-    val dict = vocabulary.zipWithIndex.toMap
-    document =>
-      val termCounts = mutable.HashMap.empty[Int, Double]
-      document.foreach { term =>
-        dict.get(term) match {
-          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
-          case None => // ignore terms not in the vocabulary
-        }
-      }
-      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
-  }
-
-  override protected def validateInputType(inputType: DataType): Unit = {
-    require(inputType.sameType(ArrayType(StringType)),
-      s"Input type must be ArrayType(StringType) but got $inputType.")
-  }
-
-  override protected def outputDataType: DataType = new VectorUDT()
-
-  override def copy(extra: ParamMap): CountVectorizerModel = {
-    val copied = new CountVectorizerModel(uid, vocabulary)
-    copyValues(copied, extra)
-  }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
new file mode 100644
index 0000000000..e192fa4850
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.Row
+
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
+  }
+
+  private def split(s: String): Seq[String] = s.split("\\s+")
+
+  test("CountVectorizerModel common cases") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a b c d"),
+        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
+      (1, split("a b b c d  a"),
+        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
+      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
+      (3, split(""), Vectors.sparse(4, Seq())), // empty string
+      (4, split("a notInDict d"),
+        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
+    )).toDF("id", "words", "expected")
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizer common cases") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a b c d e"),
+        Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
+      (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
+      (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
+      (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+    ).toDF("id", "words", "expected")
+    val cv = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .fit(df)
+    assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizer vocabSize and minDF") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+      (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+      (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+      (3, split("a"), Vectors.sparse(3, Seq((0, 1.0)))))
+    ).toDF("id", "words", "expected")
+    val cvModel = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setVocabSize(3)  // limit vocab size to 3
+      .fit(df)
+    assert(cvModel.vocabulary === Array("a", "b", "c"))
+
+    // minDF: ignore terms with count less than 3
+    val cvModel2 = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinDF(3)
+      .fit(df)
+    assert(cvModel2.vocabulary === Array("a", "b"))
+
+    cvModel2.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+
+    // minDF: ignore terms with freq < 0.75
+    val cvModel3 = new CountVectorizer()
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinDF(3.0 / df.count())
+      .fit(df)
+    assert(cvModel3.vocabulary === Array("a", "b"))
+
+    cvModel3.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizer throws exception when vocab is empty") {
+    intercept[IllegalArgumentException] {
+      val df = sqlContext.createDataFrame(Seq(
+        (0, split("a a b b c c")),
+        (1, split("aa bb cc")))
+      ).toDF("id", "words")
+      val cvModel = new CountVectorizer()
+        .setInputCol("words")
+        .setOutputCol("features")
+        .setVocabSize(3) // limit vocab size to 3
+        .setMinDF(3)
+        .fit(df)
+    }
+  }
+
+  test("CountVectorizerModel with minTF count") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
+      (2, split("a"), Vectors.sparse(4, Seq())),
+      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
+    ).toDF("id", "words", "expected")
+
+    // minTF: count
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinTF(3)
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
+  test("CountVectorizerModel with minTF freq") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
+      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
+      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
+    ).toDF("id", "words", "expected")
+
+    // minTF: count
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setMinTF(0.3)
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
deleted file mode 100644
index e90d9d4ef2..0000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.ml.feature
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
-
-  test("params") {
-    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
-  }
-
-  test("CountVectorizerModel common cases") {
-    val df = sqlContext.createDataFrame(Seq(
-      (0, "a b c d".split(" ").toSeq,
-        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
-      (1, "a b b c d  a".split(" ").toSeq,
-        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
-      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
-      (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
-      (4, "a notInDict d".split(" ").toSeq,
-        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0))))  // with words not in vocabulary
-    )).toDF("id", "words", "expected")
-    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
-      .setInputCol("words")
-      .setOutputCol("features")
-    val output = cv.transform(df).collect()
-    output.foreach { p =>
-      val features = p.getAs[Vector]("features")
-      val expected = p.getAs[Vector]("expected")
-      assert(features ~== expected absTol 1e-14)
-    }
-  }
-
-  test("CountVectorizerModel with minTermFreq") {
-    val df = sqlContext.createDataFrame(Seq(
-      (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
-      (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
-      (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
-      (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
-    ).toDF("id", "words", "expected")
-    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
-      .setInputCol("words")
-      .setOutputCol("features")
-      .setMinTermFreq(3)
-    val output = cv.transform(df).collect()
-    output.foreach { p =>
-      val features = p.getAs[Vector]("features")
-      val expected = p.getAs[Vector]("expected")
-      assert(features ~== expected absTol 1e-14)
-    }
-  }
-}
-
-
-- 
GitLab