Commit 354f4582 authored by Yuhao Yang, committed by Joseph K. Bradley

[SPARK-9028] [ML] Add CountVectorizer as an estimator to generate CountVectorizerModel

jira: https://issues.apache.org/jira/browse/SPARK-9028

Add an estimator for CountVectorizerModel. The estimator extracts a vocabulary from document collections, selecting terms by their frequency across the corpus.

I changed the meaning of minCount to be a filter across the corpus (the minDF parameter). This aligns with Word2Vec and the similar parameter in scikit-learn.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7388 from hhbyyh/cvEstimator.
parent 1968276a
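For context, a minimal usage sketch of the new estimator API (assuming a Spark 1.5-style sqlContext; the column names and toy corpus are illustrative, not part of this commit):

import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

// Toy corpus: each row is a pre-tokenized document.
val df = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "b", "c")),
  (1, Seq("a", "b", "b", "c", "a"))
)).toDF("id", "words")

// Fit over the corpus to learn the vocabulary, keeping only terms
// that appear in at least 2 documents (minDF as an absolute count).
val cvModel: CountVectorizerModel = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setVocabSize(3)
  .setMinDF(2)
  .fit(df)

cvModel.transform(df).show()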
New file: CountVectorizer.scala

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
import org.apache.spark.util.collection.OpenHashMap
/**
 * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
 */
private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol {

  /**
   * Max size of the vocabulary.
   * CountVectorizer will build a vocabulary that only considers the top
   * vocabSize terms ordered by term frequency across the corpus.
   *
   * Default: 2^18^
   * @group param
   */
  val vocabSize: IntParam =
    new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0))

  /** @group getParam */
  def getVocabSize: Int = $(vocabSize)

  /**
   * Specifies the minimum number of different documents a term must appear in to be included
   * in the vocabulary.
   * If this is an integer >= 1, this specifies the number of documents the term must appear in;
   * if this is a double in [0,1), then this specifies the fraction of documents.
   *
   * Default: 1
   * @group param
   */
  val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" +
    " different documents a term must appear in to be included in the vocabulary." +
    " If this is an integer >= 1, this specifies the number of documents the term must" +
    " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
    ParamValidators.gtEq(0.0))

  /** @group getParam */
  def getMinDF: Double = $(minDF)

  /** Validates and transforms the input schema. */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
  }

  /**
   * Filter to ignore rare words in a document. For each document, terms with
   * frequency/count less than the given threshold are ignored.
   * If this is an integer >= 1, then this specifies a count (of times the term must appear
   * in the document);
   * if this is a double in [0,1), then this specifies a fraction (out of the document's token
   * count).
   *
   * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not
   * affect fitting.
   *
   * Default: 1
   * @group param
   */
  val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" +
    " a document. For each document, terms with frequency/count less than the given threshold are" +
    " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" +
    " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" +
    " of the document's token count). Note that the parameter is only used in transform of" +
    " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0))

  setDefault(minTF -> 1)

  /** @group getParam */
  def getMinTF: Double = $(minTF)
}
/**
 * :: Experimental ::
 * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]].
 */
@Experimental
class CountVectorizer(override val uid: String)
  extends Estimator[CountVectorizerModel] with CountVectorizerParams {

  def this() = this(Identifiable.randomUID("cntVec"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setVocabSize(value: Int): this.type = set(vocabSize, value)

  /** @group setParam */
  def setMinDF(value: Double): this.type = set(minDF, value)

  /** @group setParam */
  def setMinTF(value: Double): this.type = set(minTF, value)

  setDefault(vocabSize -> (1 << 18), minDF -> 1)

  override def fit(dataset: DataFrame): CountVectorizerModel = {
    transformSchema(dataset.schema, logging = true)
    val vocSize = $(vocabSize)
    val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
    // Resolve minDF to an absolute document count (a fractional minDF is scaled by corpus size).
    val minDf = if ($(minDF) >= 1.0) {
      $(minDF)
    } else {
      $(minDF) * input.cache().count()
    }
    // Aggregate (term frequency, document frequency) per term, then drop terms below minDF.
    val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
      val wc = new OpenHashMap[String, Long]
      tokens.foreach { w =>
        wc.changeValue(w, 1L, _ + 1L)
      }
      wc.map { case (word, count) => (word, (count, 1)) }
    }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
      (wc1 + wc2, df1 + df2)
    }.filter { case (word, (wc, df)) =>
      df >= minDf
    }.map { case (word, (count, dfCount)) =>
      (word, count)
    }.cache()
    val fullVocabSize = wordCounts.count()
    val vocab: Array[String] = {
      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
        // Use all terms
        wordCounts.collect().sortBy(-_._2)
      } else {
        // Sort terms to select vocab
        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
      }
      tmpSortedWC.map(_._1)
    }
    require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
    copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
}
/**
 * :: Experimental ::
 * Converts a text document to a sparse vector of token counts.
 * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
 */
@Experimental
class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
  extends Model[CountVectorizerModel] with CountVectorizerParams {

  def this(vocabulary: Array[String]) = {
    this(Identifiable.randomUID("cntVecModel"), vocabulary)
    set(vocabSize, vocabulary.length)
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setMinTF(value: Double): this.type = set(minTF, value)

  /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
  private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None

  override def transform(dataset: DataFrame): DataFrame = {
    if (broadcastDict.isEmpty) {
      val dict = vocabulary.zipWithIndex.toMap
      broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
    }
    val dictBr = broadcastDict.get
    val minTf = $(minTF)
    val vectorizer = udf { (document: Seq[String]) =>
      val termCounts = new OpenHashMap[Int, Double]
      var tokenCount = 0L
      document.foreach { term =>
        dictBr.value.get(term) match {
          case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
          case None => // ignore terms not in the vocabulary
        }
        tokenCount += 1
      }
      // A fractional minTF is interpreted relative to the document's token count.
      val effectiveMinTF = if (minTf >= 1.0) {
        minTf
      } else {
        tokenCount * minTf
      }
      Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
    }
    dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): CountVectorizerModel = {
    val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
    copyValues(copied, extra)
  }
}
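A CountVectorizerModel can also be constructed directly from an a-priori vocabulary, with minTF applied per document at transform time. A small sketch, assuming the same illustrative sqlContext setup as above:

import org.apache.spark.ml.feature.CountVectorizerModel

val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "a", "b", "c"))
)).toDF("id", "words")

// With minTF = 0.5, a term must cover at least half of the document's
// 4 tokens (i.e. appear at least 2 times) to be counted.
val cvm = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTF(0.5)

cvm.transform(docs).show()  // only "a" survives: (3,[0],[2.0])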
Deleted file: the previous UnaryTransformer-based CountVectorizerModel

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
import org.apache.spark.sql.types.{StringType, ArrayType, DataType}

/**
 * :: Experimental ::
 * Converts a text document to a sparse vector of token counts.
 * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
 */
@Experimental
class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
  extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {

  def this(vocabulary: Array[String]) =
    this(Identifiable.randomUID("cntVec"), vocabulary)

  /**
   * Corpus-specific filter to ignore scarce words in a document. For each document, terms with
   * frequency (count) less than the given threshold are ignored.
   * Default: 1
   * @group param
   */
  val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
    "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
    "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))

  /** @group setParam */
  def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)

  /** @group getParam */
  def getMinTermFreq: Int = $(minTermFreq)

  setDefault(minTermFreq -> 1)

  override protected def createTransformFunc: Seq[String] => Vector = {
    val dict = vocabulary.zipWithIndex.toMap
    document =>
      val termCounts = mutable.HashMap.empty[Int, Double]
      document.foreach { term =>
        dict.get(term) match {
          case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
          case None => // ignore terms not in the vocabulary
        }
      }
      Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
  }

  override protected def validateInputType(inputType: DataType): Unit = {
    require(inputType.sameType(ArrayType(StringType)),
      s"Input type must be ArrayType(StringType) but got $inputType.")
  }

  override protected def outputDataType: DataType = new VectorUDT()

  override def copy(extra: ParamMap): CountVectorizerModel = {
    val copied = new CountVectorizerModel(uid, vocabulary)
    copyValues(copied, extra)
  }
}
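For reference, migrating from the removed transformer to the new model is mechanical; a before/after sketch under the same illustrative names as above (the old setter accepted only integer counts, while the new minTF also accepts a fraction in [0,1)):

// Before this commit (UnaryTransformer-based model):
val oldModel = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTermFreq(3)

// After this commit (Model sharing CountVectorizerParams):
val newModel = new CountVectorizerModel(Array("a", "b", "c"))
  .setInputCol("words")
  .setOutputCol("features")
  .setMinTF(3)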
Updated file: CountVectorizerSuite.scala (reconstructed post-commit excerpt; the original diff interleaved old and new lines)

import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row

class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {

  // ... (lines unchanged by this commit omitted in the original diff)
    ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
  }

  private def split(s: String): Seq[String] = s.split("\\s+")

  test("CountVectorizerModel common cases") {
    val df = sqlContext.createDataFrame(Seq(
      (0, split("a b c d"),
        Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
      (1, split("a b b c d a"),
        Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
      (3, split(""), Vectors.sparse(4, Seq())), // empty string
      (4, split("a notInDict d"),
        Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary
    )).toDF("id", "words", "expected")
    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      .setInputCol("words")
      .setOutputCol("features")
    cv.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("CountVectorizer common cases") {
    val df = sqlContext.createDataFrame(Seq(
      (0, split("a b c d e"),
        Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
      (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
      (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
      (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
    ).toDF("id", "words", "expected")
    val cv = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .fit(df)
    assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
    cv.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("CountVectorizer vocabSize and minDF") {
    val df = sqlContext.createDataFrame(Seq(
      (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
      (3, split("a"), Vectors.sparse(3, Seq((0, 1.0)))))
    ).toDF("id", "words", "expected")
    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(3) // limit vocab size to 3
      .fit(df)
    assert(cvModel.vocabulary === Array("a", "b", "c"))

    // minDF: ignore terms present in fewer than 3 documents
    val cvModel2 = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setMinDF(3)
      .fit(df)
    assert(cvModel2.vocabulary === Array("a", "b"))
    cvModel2.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }

    // minDF: ignore terms with document frequency below 0.75
    val cvModel3 = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setMinDF(3.0 / df.count())
      .fit(df)
    assert(cvModel3.vocabulary === Array("a", "b"))
    cvModel3.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("CountVectorizer throws exception when vocab is empty") {
    intercept[IllegalArgumentException] {
      val df = sqlContext.createDataFrame(Seq(
        (0, split("a a b b c c")),
        (1, split("aa bb cc")))
      ).toDF("id", "words")
      val cvModel = new CountVectorizer()
        .setInputCol("words")
        .setOutputCol("features")
        .setVocabSize(3) // limit vocab size to 3
        .setMinDF(3)
        .fit(df)
    }
  }

  test("CountVectorizerModel with minTF count") {
    val df = sqlContext.createDataFrame(Seq(
      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
      (2, split("a"), Vectors.sparse(4, Seq())),
      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
    ).toDF("id", "words", "expected")

    // minTF: count
    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTF(3)
    cv.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }

  test("CountVectorizerModel with minTF freq") {
    val df = sqlContext.createDataFrame(Seq(
      (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
      (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
      (3, split("e e e e e"), Vectors.sparse(4, Seq())))
    ).toDF("id", "words", "expected")

    // minTF: fraction of the document's token count
    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTF(0.3)
    cv.transform(df).select("features", "expected").collect().foreach {
      case Row(features: Vector, expected: Vector) =>
        assert(features ~== expected absTol 1e-14)
    }
  }
}