Skip to content
Snippets Groups Projects
Commit 2a6590e5 authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-9981] [ML] Made labels public for StringIndexerModel

Also added unit test for integration between StringIndexerModel and IndexToString

CC: holdenk We realized we should have left in your unit test (to catch the issue with removing the inverse() method), so this adds it back.  mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8211 from jkbradley/stridx-labels.
parent 11ed2b18
No related branches found
No related tags found
No related merge requests found
......@@ -97,14 +97,17 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
/**
* :: Experimental ::
* Model fitted by [[StringIndexer]].
*
* NOTE: During transformation, if the input column does not exist,
* [[StringIndexerModel.transform]] would return the input dataset unmodified.
* This is a temporary fix for the case when target labels do not exist during prediction.
*
* @param labels Ordered list of labels, corresponding to indices to be assigned
*/
@Experimental
class StringIndexerModel (
override val uid: String,
labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
......
......@@ -147,4 +147,22 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(actual === expected)
}
}
test("StringIndexer, IndexToString are inverses") {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
val transformed = indexer.transform(df)
val idx2str = new IndexToString()
.setInputCol("labelIndex")
.setOutputCol("sameLabel")
.setLabels(indexer.labels)
idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
case Row(a: String, b: String) =>
assert(a === b)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment