Skip to content
Snippets Groups Projects
Commit 85941ecf authored by Menglong TAN's avatar Menglong TAN Committed by Joseph K. Bradley
Browse files

[SPARK-11569][ML] Fix StringIndexer to handle null value properly

## What changes were proposed in this pull request?

This PR is to enhance StringIndexer with NULL values handling.

Before the PR, StringIndexer will throw an exception when encounters NULL values.
With this PR:
- handleInvalid=error: Throw an exception as before
- handleInvalid=skip: Skip null values as well as unseen labels
- handleInvalid=keep: Give null values an additional index as well as unseen labels

BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-)

## How was this patch tested?

new unit tests

Author: Menglong TAN <tanmenglong@renrenche.com>
Author: Menglong TAN <tanmenglong@gmail.com>

Closes #17233 from crackcell/11569_StringIndexer_NULL.
parent d4a637cd
No related branches found
No related tags found
No related merge requests found
...@@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap ...@@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
/** /**
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with * Param for how to handle invalid data (unseen labels or NULL values).
* unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional * Options are 'skip' (filter out rows with invalid data),
* bucket, at index numLabels. * 'error' (throw an error), or 'keep' (put invalid data in a special additional
* bucket, at index numLabels).
* Default: "error" * Default: "error"
* @group param * @group param
*/ */
@Since("1.6.0") @Since("1.6.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " + "invalid data (unseen labels or NULL values). " +
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
"at index numLabels).", "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
/** @group getParam */ /** @group getParam */
@Since("1.6.0") @Since("1.6.0")
...@@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( ...@@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0") @Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = { override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true) transformSchema(dataset.schema, logging = true)
val counts = dataset.select(col($(inputCol)).cast(StringType)) val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd .rdd
.map(_.getString(0)) .map(_.getString(0))
.countByValue() .countByValue()
...@@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") ( ...@@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0") @Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] { object StringIndexer extends DefaultParamsReadable[StringIndexer] {
private[feature] val SKIP_UNSEEN_LABEL: String = "skip" private[feature] val SKIP_INVALID: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error" private[feature] val ERROR_INVALID: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep" private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] = private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
@Since("1.6.0") @Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path) override def load(path: String): StringIndexer = super.load(path)
...@@ -188,7 +189,7 @@ class StringIndexerModel ( ...@@ -188,7 +189,7 @@ class StringIndexerModel (
transformSchema(dataset.schema, logging = true) transformSchema(dataset.schema, logging = true)
val filteredLabels = getHandleInvalid match { val filteredLabels = getHandleInvalid match {
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
case _ => labels case _ => labels
} }
...@@ -196,22 +197,31 @@ class StringIndexerModel ( ...@@ -196,22 +197,31 @@ class StringIndexerModel (
.withName($(outputCol)).withValues(filteredLabels).toMetadata() .withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out. // If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match { val (filteredDataset, keepInvalid) = getHandleInvalid match {
case StringIndexer.SKIP_UNSEEN_LABEL => case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String => val filterer = udf { label: String =>
labelToIndex.contains(label) labelToIndex.contains(label)
} }
(dataset.where(filterer(dataset($(inputCol)))), false) (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
} }
val indexer = udf { label: String => val indexer = udf { label: String =>
if (labelToIndex.contains(label)) { if (label == null) {
labelToIndex(label) if (keepInvalid) {
} else if (keepInvalid) { labels.length
labels.length } else {
throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
"NULLS, try setting StringIndexer.handleInvalid.")
}
} else { } else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + if (labelToIndex.contains(label)) {
s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") labelToIndex(label)
} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
}
} }
} }
......
...@@ -122,6 +122,51 @@ class StringIndexerSuite ...@@ -122,6 +122,51 @@ class StringIndexerSuite
assert(output === expected) assert(output === expected)
} }
test("StringIndexer with NULLs") {
val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
val df = data.toDF("id", "label")
val df2 = data2.toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
withClue("StringIndexer should throw error when setHandleInvalid=error " +
"when given NULL values") {
intercept[SparkException] {
indexer.setHandleInvalid("error")
indexer.fit(df).transform(df2).collect()
}
}
indexer.setHandleInvalid("skip")
val transformedSkip = indexer.fit(df).transform(df2)
val attrSkip = Attribute
.fromStructField(transformedSkip.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrSkip.values.get === Array("b", "a"))
val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0
val expectedSkip = Set((0, 1.0), (1, 0.0))
assert(outputSkip === expectedSkip)
indexer.setHandleInvalid("keep")
val transformedKeep = indexer.fit(df).transform(df2)
val attrKeep = Attribute
.fromStructField(transformedKeep.schema("labelIndex"))
.asInstanceOf[NominalAttribute]
assert(attrKeep.values.get === Array("b", "a", "__unknown"))
val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
(r.getInt(0), r.getDouble(1))
}.collect().toSet
// a -> 1, b -> 0, null -> 2
val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
assert(outputKeep === expectedKeep)
}
test("StringIndexerModel should keep silent if the input column does not exist.") { test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label") .setInputCol("label")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment