Skip to content
Snippets Groups Projects
Commit 685ddcf5 authored by Xiangrui Meng
Browse files

[SPARK-5886][ML] Add StringIndexer as a feature transformer

This PR adds StringIndexer, a feature transformer that takes a column of string labels and outputs a double column in which each label is replaced by its index, with labels indexed in descending order of frequency.

TODOs:
- [x] store feature to index map in output metadata

Author: Xiangrui Meng <meng@databricks.com>

Closes #4735 from mengxr/SPARK-5886 and squashes the following commits:

d82575f [Xiangrui Meng] fix test
700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer
16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
f8b30f4 [Xiangrui Meng] update label indexer to output metadata
e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2
d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap
748a69b [Xiangrui Meng] add contains to OpenHashMap
def3c5c [Xiangrui Meng] add LabelIndexer
parent d3792f54
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

  /**
   * Validates the input schema and derives the output schema.
   *
   * Checks the input column against [[StringType]] via `checkInputColumn`, requires that no
   * existing field already carries the output column name, and returns the input fields
   * followed by a default nominal attribute field named after the output column.
   *
   * @param schema input schema
   * @param paramMap params combined with this instance's own param map
   * @return the input schema with the output field appended
   */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val merged = this.paramMap ++ paramMap
    checkInputColumn(schema, merged(inputCol), StringType)
    val outputColName = merged(outputCol)
    val existingFields = schema.fields
    require(!existingFields.exists(_.name == outputColName),
      s"Output column $outputColName already exists.")
    val outputField = NominalAttribute.defaultAttr.withName(outputColName).toStructField()
    StructType(existingFields :+ outputField)
  }
}
/**
* :: AlphaComponent ::
* A label indexer that maps a string column of labels to an ML column of label indices.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
*/
@AlphaComponent
class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // TODO: handle unseen labels

  /**
   * Counts the values of the input column and builds a [[StringIndexerModel]] whose label
   * array is ordered by descending count, so the most frequent label gets index 0.
   *
   * Labels with equal counts are ordered by the label itself, so the resulting indices do not
   * depend on the iteration order of the count map.
   *
   * @param dataset dataset containing the string input column
   * @param paramMap params combined with this instance's own param map
   * @return the fitted model
   */
  override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
    val map = this.paramMap ++ paramMap
    val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
    // Primary key: descending count. Secondary key: the label, to make ties deterministic.
    // Wrapping the label in Option keeps the comparison safe if a label is null.
    val labels = counts.toSeq
      .sortBy { case (label, count) => (-count, Option(label)) }
      .map(_._1)
      .toArray
    val model = new StringIndexerModel(this, map, labels)
    Params.inheritValues(map, this, model)
    model
  }

  /** Delegates to the shared schema validation in [[StringIndexerBase]]. */
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
/**
* :: AlphaComponent ::
* Model fitted by [[StringIndexer]].
*/
@AlphaComponent
class StringIndexerModel private[ml] (
    override val parent: StringIndexer,
    override val fittingParamMap: ParamMap,
    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {

  /** Lookup from each label to its position in `labels`, stored as a double. */
  private val labelToIndex: OpenHashMap[String, Double] = {
    val lookup = new OpenHashMap[String, Double](labels.length)
    labels.zipWithIndex.foreach { case (label, position) =>
      lookup.update(label, position.toDouble)
    }
    lookup
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Appends the output column holding the index of each input label, and attaches nominal
   * attribute metadata listing the labels. A label that is not in the lookup makes the UDF
   * throw a [[SparkException]].
   */
  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    val merged = this.paramMap ++ paramMap
    val indexer = udf { label: String =>
      // TODO: handle unseen labels
      if (!labelToIndex.contains(label)) {
        throw new SparkException(s"Unseen label: $label.")
      }
      labelToIndex(label)
    }
    val outputColName = merged(outputCol)
    val nominalAttr = NominalAttribute.defaultAttr.withName(outputColName).withValues(labels)
    val metadata = nominalAttr.toStructField().metadata
    val indexedCol = indexer(dataset(merged(inputCol))).as(outputColName, metadata)
    dataset.select(col("*"), indexedCol)
  }

  /** Delegates to the shared schema validation in [[StringIndexerBase]]. */
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.scalatest.FunSuite
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SQLContext
class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {

  private var sqlContext: SQLContext = _

  /** Creates the SQLContext once the shared SparkContext has been set up. */
  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("StringIndexer") {
    // Label counts: "a" x3, "c" x2, "b" x1.
    val rows = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    val df = sqlContext.createDataFrame(sc.parallelize(rows, 2)).toDF("id", "label")

    val estimator = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
    val model = estimator.fit(df)
    val transformed = model.transform(df)

    // The output column metadata should list the labels by descending frequency.
    val labelIndexField = transformed.schema("labelIndex")
    val attr = Attribute.fromStructField(labelIndexField).asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))

    val output = transformed
      .select("id", "labelIndex")
      .map(row => (row.getInt(0), row.getDouble(1)))
      .collect()
      .toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    assert(output === expected)
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment