Skip to content
Snippets Groups Projects
Commit 6c5858bc authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-9922] [ML] rename StringIndexerReverse to IndexToString

What `StringIndexerInverse` does is not strictly associated with `StringIndexer`, and the name is not clearly describing the transformation. Renaming to `IndexToString` might be better.

~~I also changed `invert` to `inverse` without arguments. `inputCol` and `outputCol` could be set after.~~
I also removed `invert`.

jkbradley holdenk

Author: Xiangrui Meng <meng@databricks.com>

Closes #8152 from mengxr/SPARK-9922.
parent c2520f50
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
......@@ -59,6 +59,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
* If the input column is numeric, we cast it to string and index the string values.
* The indices are in [0, numLabels), ordered by label frequencies.
* So the most frequent label gets index 0.
*
* @see [[IndexToString]] for the inverse transformation
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
......@@ -170,34 +172,24 @@ class StringIndexerModel private[ml] (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra).setParent(parent)
}
/**
 * Returns a transformer that performs the inverse transformation: mapping a column of
 * label indices back to a column of the original label strings, using this model's labels.
 *
 * Note: By default we keep the original columns during this transformation, so the inverse
 * should only be used on new columns such as predicted labels.
 *
 * @param inputCol name of the input column containing label indices
 * @param outputCol name of the output column to hold the recovered label strings
 * @return a [[StringIndexerInverse]] configured with this model's labels
 */
def invert(inputCol: String, outputCol: String): StringIndexerInverse = {
new StringIndexerInverse()
.setInputCol(inputCol)
.setOutputCol(outputCol)
.setLabels(labels)
}
}
/**
* :: Experimental ::
* Transform a provided column back to the original input types using either the metadata
* on the input column, or if provided using the labels supplied by the user.
* Note: By default we keep the original columns during this transformation,
* so the inverse should only be used on new columns such as predicted labels.
* A [[Transformer]] that maps a column of string indices back to a new column of corresponding
* string values using either the ML attributes of the input column, or if provided using the labels
* supplied by the user.
* All original columns are kept during transformation.
*
* @see [[StringIndexer]] for converting strings into indices
*/
@Experimental
class StringIndexerInverse private[ml] (
class IndexToString private[ml] (
override val uid: String) extends Transformer
with HasInputCol with HasOutputCol {
def this() =
this(Identifiable.randomUID("strIdxInv"))
this(Identifiable.randomUID("idxToStr"))
/** Sets the name of the input column (the column of indices to map back to strings). @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
......@@ -257,7 +249,7 @@ class StringIndexerInverse private[ml] (
}
val indexer = udf { index: Double =>
val idx = index.toInt
if (0 <= idx && idx < values.size) {
if (0 <= idx && idx < values.length) {
values(idx)
} else {
throw new SparkException(s"Unseen index: $index ??")
......@@ -268,7 +260,7 @@ class StringIndexerInverse private[ml] (
indexer(dataset($(inputCol)).cast(DoubleType)).as(outputColName))
}
override def copy(extra: ParamMap): StringIndexerInverse = {
// Creates a copy of this transformer with the extra params merged in; delegates to the
// default Params copy implementation (no model state beyond params to carry over).
override def copy(extra: ParamMap): IndexToString = {
defaultCopy(extra)
}
}
......@@ -17,12 +17,13 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
......@@ -53,19 +54,6 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
// reverse the transform to recover the original string labels
val reversed = indexer.invert("labelIndex", "label2")
.transform(transformed)
.select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
// Check invert using only metadata
val inverse2 = new StringIndexerInverse()
.setInputCol("labelIndex")
.setOutputCol("label2")
val reversed2 = inverse2.transform(transformed).select("id", "label2")
assert(df.collect().map(r => (r.getInt(0), r.getString(1))).toSet ===
reversed2.collect().map(r => (r.getInt(0), r.getString(1))).toSet)
}
test("StringIndexerUnseen") {
......@@ -125,4 +113,36 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val df = sqlContext.range(0L, 10L)
assert(indexerModel.transform(df).eq(df))
}
test("IndexToString params") {
// Run the shared Params contract checker against a freshly constructed transformer.
val transformer = new IndexToString()
ParamsSuite.checkParams(transformer)
}
test("IndexToString.transform") {
val labelStrings = Array("a", "b", "c")
val indexedDF = sqlContext.createDataFrame(Seq(
(0, "a"), (1, "b"), (2, "c"), (0, "a")
)).toDF("index", "expected")

// Case 1: labels are supplied explicitly through setLabels.
val withExplicitLabels = new IndexToString()
.setInputCol("index")
.setOutputCol("actual")
.setLabels(labelStrings)
for (Row(mapped, original) <-
withExplicitLabels.transform(indexedDF).select("actual", "expected").collect()) {
assert(mapped === original)
}

// Case 2: labels are recovered from the ML attribute metadata on the input column.
val nominalAttr = NominalAttribute.defaultAttr.withValues(labelStrings)
val attributedDF = indexedDF.select(
col("index").as("indexWithAttr", nominalAttr.toMetadata()), col("expected"))
val withMetadataLabels = new IndexToString()
.setInputCol("indexWithAttr")
.setOutputCol("actual")
for (Row(mapped, original) <-
withMetadataLabels.transform(attributedDF).select("actual", "expected").collect()) {
assert(mapped === original)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment