Skip to content
Snippets Groups Projects
Commit b2a22a65 authored by Xiangrui Meng's avatar Xiangrui Meng Committed by Joseph K. Bradley
Browse files

[SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist


This is just a workaround to a bigger problem. Some pipeline stages may not be effective during prediction, and they should not complain about missing required columns, e.g. `StringIndexerModel`. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6595 from mengxr/SPARK-8051 and squashes the following commits:

b6a36b9 [Xiangrui Meng] add doc
f143fd4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8051
8ee7c7e [Xiangrui Meng] use SparkFunSuite
e112394 [Xiangrui Meng] make StringIndexerModel silent if input column does not exist

(cherry picked from commit 26c9d7a0)
Signed-off-by: default avatarJoseph K. Bradley <joseph@databricks.com>
parent ca21fff7
No related branches found
No related tags found
No related merge requests found
...@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod ...@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
/** /**
* :: Experimental :: * :: Experimental ::
* Model fitted by [[StringIndexer]]. * Model fitted by [[StringIndexer]].
* NOTE: During transformation, if the input column does not exist,
* [[StringIndexerModel.transform]] would return the input dataset unmodified.
* This is a temporary fix for the case when target labels do not exist during prediction.
*/ */
@Experimental @Experimental
class StringIndexerModel private[ml] ( class StringIndexerModel private[ml] (
...@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] ( ...@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = { override def transform(dataset: DataFrame): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
return dataset
}
val indexer = udf { label: String => val indexer = udf { label: String =>
if (labelToIndex.contains(label)) { if (labelToIndex.contains(label)) {
labelToIndex(label) labelToIndex(label)
...@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] ( ...@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
} }
override def transformSchema(schema: StructType): StructType = { override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema) if (schema.fieldNames.contains($(inputCol))) {
validateAndTransformSchema(schema)
} else {
// If the input column does not exist during transformation, we skip StringIndexerModel.
schema
}
} }
} }
...@@ -61,4 +61,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { ...@@ -61,4 +61,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected) assert(output === expected)
} }
test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")
.setOutputCol("labelIndex")
val df = sqlContext.range(0L, 10L)
assert(indexerModel.transform(df).eq(df))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment