Commit 6e57d57b authored by Xusen Yin, committed by Xiangrui Meng

[SPARK-6528] [ML] Add IDF transformer

See [SPARK-6528](https://issues.apache.org/jira/browse/SPARK-6528). Add IDF transformer in ML package.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5266 from yinxusen/SPARK-6528 and squashes the following commits:

741db31 [Xusen Yin] get param from new paramMap
d169967 [Xusen Yin] add final to param and IDF class
c9c3759 [Xusen Yin] simplify test suite
5867c09 [Xusen Yin] refine IDF transformer with new interfaces
7727cae [Xusen Yin] Merge branch 'master' into SPARK-6528
4338a37 [Xusen Yin] Merge branch 'master' into SPARK-6528
aef2cdf [Xusen Yin] add doc and group for param
5760b49 [Xusen Yin] fix code style
2add691 [Xusen Yin] fix code style and test
03fbecb [Xusen Yin] remove duplicated code
2aa4be0 [Xusen Yin] clean test suite
4802c67 [Xusen Yin] add IDF transformer and test suite
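
For context, a minimal usage sketch of the new estimator, mirroring the pattern in the test suite below. The DataFrame docs and the column names "features" and "tfidf" are illustrative, not part of this commit; docs is assumed to have a Vector column "features" of term-frequency counts.

  val idfModel = new IDF()
    .setInputCol("features")
    .setOutputCol("tfidf")
    .fit(docs)                          // learns per-term IDF weights from docs
  val rescaled = idfModel.transform(docs) // appends a "tfidf" Vector column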
parent 78b39c7e

mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StructType
/**
 * Params for [[IDF]] and [[IDFModel]].
 */
private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol {

  /**
   * The minimum number of documents in which a term should appear.
   * @group param
   */
  final val minDocFreq = new IntParam(
    this, "minDocFreq", "minimum number of documents in which a term should appear for filtering")

  setDefault(minDocFreq -> 0)

  /** @group getParam */
  def getMinDocFreq: Int = getOrDefault(minDocFreq)

  /** @group setParam */
  def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

  /**
   * Validate and transform the input schema.
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = extractParamMap(paramMap)
    SchemaUtils.checkColumnType(schema, map(inputCol), new VectorUDT)
    SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
  }
}
/**
 * :: AlphaComponent ::
 * Compute the Inverse Document Frequency (IDF) given a collection of documents.
 */
@AlphaComponent
final class IDF extends Estimator[IDFModel] with IDFBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def fit(dataset: DataFrame, paramMap: ParamMap): IDFModel = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = extractParamMap(paramMap)
    val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
    val idf = new feature.IDF(map(minDocFreq)).fit(input)
    val model = new IDFModel(this, map, idf)
    Params.inheritValues(map, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
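
// Note on the formula: the wrapped mllib feature.IDF computes the smoothed
// inverse document frequency
//   idf(t) = log((m + 1) / (df(t) + 1))
// where m is the number of documents and df(t) is the number of documents
// containing term t; terms with df(t) below minDocFreq receive an IDF of 0.
// Both behaviors are exercised by the test suite below.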
/**
 * :: AlphaComponent ::
 * Model fitted by [[IDF]].
 */
@AlphaComponent
class IDFModel private[ml] (
    override val parent: IDF,
    override val fittingParamMap: ParamMap,
    idfModel: feature.IDFModel)
  extends Model[IDFModel] with IDFBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    transformSchema(dataset.schema, paramMap, logging = true)
    val map = extractParamMap(paramMap)
    val idf = udf { vec: Vector => idfModel.transform(vec) }
    dataset.withColumn(map(outputCol), idf(col(map(inputCol))))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}

mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
class IDFSuite extends FunSuite with MLlibTestSparkContext {

  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  // Multiplies each input vector element-wise by the given IDF vector.
  def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) =>
          (id, value * model(id))
        }
        Vectors.sparse(data.size, res)
    }
  }
test("compute IDF with default parameter") {
val numOfFeatures = 4
val data = Array(
Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
)
val numOfData = data.size
val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
math.log((numOfData + 1.0) / (x + 1.0))
})
val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")
.setOutputCol("idfValue")
.fit(df)
idfModel.transform(df).select("idfValue", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}
test("compute IDF with setter") {
val numOfFeatures = 4
val data = Array(
Vectors.sparse(numOfFeatures, Array(1, 3), Array(1.0, 2.0)),
Vectors.dense(0.0, 1.0, 2.0, 3.0),
Vectors.sparse(numOfFeatures, Array(1), Array(1.0))
)
val numOfData = data.size
val idf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
if (x > 0) math.log((numOfData + 1.0) / (x + 1.0)) else 0
})
val expected = scaleDataWithIDF(data, idf)
val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected")
val idfModel = new IDF()
.setInputCol("features")
.setOutputCol("idfValue")
.setMinDocFreq(1)
.fit(df)
idfModel.transform(df).select("idfValue", "expected").collect().foreach {
case Row(x: Vector, y: Vector) =>
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}
}
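
As a sanity check on the first test's hand-computed expectations: the three documents give per-feature document frequencies df = [0, 3, 1, 2] with m = 3 documents, so the smoothed formula yields the values below. This is a REPL sketch; the names m, df, and idf are local to the illustration.

  val m = 3.0
  val df = Array(0, 3, 1, 2)
  val idf = df.map(x => math.log((m + 1.0) / (x + 1.0)))
  // idf ≈ Array(1.3863, 0.0, 0.6931, 0.2877)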