Commit 85b96372 authored by Xiangrui Meng, committed by Joseph K. Bradley

[SPARK-7219] [MLLIB] Output feature attributes in HashingTF

This PR updates `HashingTF` to output ML attributes that record the number of features in the output column. We need to expand `UnaryTransformer` to support output metadata. A `def outputMetadata: Metadata` is not sufficient because the metadata may also depend on the input data. Though this is not true for `HashingTF`, I think it is reasonable to update `UnaryTransformer` in a separate PR. `checkParams` is added to verify common requirements for params. I will send a separate PR to use it in other test suites. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6308 from mengxr/SPARK-7219 and squashes the following commits:

9bd2922 [Xiangrui Meng] address comments
e82a68a [Xiangrui Meng] remove sqlContext from test suite
995535b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7219
2194703 [Xiangrui Meng] add test for attributes
178ae23 [Xiangrui Meng] update HashingTF with tests
91a6106 [Xiangrui Meng] WIP
parent f5db4b41
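For context, a minimal sketch (not part of the commit) of what this change enables. It mirrors the test suite below and assumes a `sqlContext` is in scope:

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.HashingTF

// Sketch: the output column now carries ML attributes recording its size.
val df = sqlContext.createDataFrame(Seq(
  (0, "a a b b c d".split(" ").toSeq)
)).toDF("id", "words")
val hashed = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(100)
  .transform(df)
// Returns Some(100), read from the output column's metadata
// rather than from the data itself.
val numFeatures = AttributeGroup.fromStructField(hashed.schema("features")).numAttributes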
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

@@ -18,22 +18,31 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.param.{IntParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions.{udf, col}
+import org.apache.spark.sql.types.{ArrayType, StructType}
 
 /**
  * :: AlphaComponent ::
  * Maps a sequence of terms to their term frequencies using the hashing trick.
  */
 @AlphaComponent
-class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
+class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
 
   def this() = this(Identifiable.randomUID("hashingTF"))
 
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /**
    * Number of features. Should be > 0.
    * (default = 2^18^)
@@ -50,10 +59,19 @@ class HashingTF(override val uid: String) extends UnaryTransformer[Iterable[_],
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
-  override protected def createTransformFunc: Iterable[_] => Vector = {
+  override def transform(dataset: DataFrame): DataFrame = {
+    val outputSchema = transformSchema(dataset.schema)
     val hashingTF = new feature.HashingTF($(numFeatures))
-    hashingTF.transform
+    val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
+    val metadata = outputSchema($(outputCol)).metadata
+    dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
   }
 
-  override protected def outputDataType: DataType = new VectorUDT()
+  override def transformSchema(schema: StructType): StructType = {
+    val inputType = schema($(inputCol)).dataType
+    require(inputType.isInstanceOf[ArrayType],
+      s"The input column must be ArrayType, but got $inputType.")
+    val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
+    SchemaUtils.appendColumn(schema, attrGroup.toStructField())
+  }
 }
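The key mechanism in `transform` above is `Column.as(alias, metadata)`, which attaches the metadata built by `transformSchema` to the output column. A minimal standalone sketch of that pattern, with an illustrative DataFrame `df` and column name that are not part of the commit:

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.sql.functions.col

// Sketch: tag an existing vector column with an AttributeGroup of 100
// attributes so that downstream stages can read the feature count from
// the schema instead of scanning the data.
val meta = new AttributeGroup("features", 100).toMetadata()
val tagged = df.select(col("features").as("features", meta))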
mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class HashingTFSuite extends FunSuite with MLlibTestSparkContext {

  test("params") {
    val hashingTF = new HashingTF
    ParamsSuite.checkParams(hashingTF, 3)
  }

  test("hashingTF") {
    val df = sqlContext.createDataFrame(Seq(
      (0, "a a b b c d".split(" ").toSeq)
    )).toDF("id", "words")
    val n = 100
    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")
      .setNumFeatures(n)
    val output = hashingTF.transform(df)
    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
    require(attrGroup.numAttributes === Some(n))
    val features = output.select("features").first().getAs[Vector](0)
    // Assume perfect hash on "a", "b", "c", and "d".
    def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
    val expected = Vectors.sparse(n,
      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
    assert(features ~== expected absTol 1e-14)
  }
}
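For reference, the expected indices in the test mirror how `mllib.feature.HashingTF` assigns buckets: a term's hash code reduced modulo the number of features, kept non-negative. A standalone sketch of the equivalent computation, using only the JDK:

// Sketch of the hashing-trick index: term.## (Scala's hashCode) modulo
// numFeatures, shifted into the non-negative range, as Utils.nonNegativeMod does.
def termIndex(term: Any, numFeatures: Int): Int = {
  val raw = term.## % numFeatures
  if (raw < 0) raw + numFeatures else raw
}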
mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

@@ -201,3 +201,23 @@ class ParamsSuite extends FunSuite {
     assert(inArray(1) && inArray(2) && !inArray(0))
   }
 }
+
+object ParamsSuite extends FunSuite {
+
+  /**
+   * Checks common requirements for [[Params.params]]: 1) number of params; 2) params are ordered
+   * by names; 3) param parent has the same UID as the object's UID; 4) param name is the same as
+   * the param method name.
+   */
+  def checkParams(obj: Params, expectedNumParams: Int): Unit = {
+    val params = obj.params
+    require(params.length === expectedNumParams,
+      s"Expect $expectedNumParams params but got ${params.length}: ${params.map(_.name).toSeq}.")
+    val paramNames = params.map(_.name)
+    require(paramNames === paramNames.sorted)
+    params.foreach { p =>
+      assert(p.parent === obj.uid)
+      assert(obj.getParam(p.name) === p)
+    }
+  }
+}
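As the commit message notes, `checkParams` is intended for reuse in other test suites. A hypothetical usage sketch, where `MyTransformer` is a placeholder for any ML stage and the expected count must match the params it defines:

import org.apache.spark.ml.param.ParamsSuite

// Hypothetical reuse in another suite: checkParams verifies the param
// count, alphabetical name ordering, and that each param's parent is the
// stage's UID.
test("params") {
  ParamsSuite.checkParams(new MyTransformer, 3)
}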