Commit a7b46c62 authored by wangmiao1981, committed by Felix Cheung

[SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to spark.mllib functions that use StringIndexer

## What changes were proposed in this pull request?

For the random forest classifier, prediction throws an error when the test data contains labels unseen during training. StringIndexer already implements handleInvalid logic; this patch adds a new handleInvalid parameter that is passed through to the underlying StringIndexer.

The same change applies to other classifiers as well. This PR focuses on the core logic and the random forest classifier; follow-up PRs will cover the remaining classifiers.
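A minimal usage sketch, adapted from the new test case below (the data and column names are illustrative; "skip" is the option the test exercises):

```r
library(SparkR)
sparkR.session()

# Training data; "the other" never occurs in someString during training.
traindf <- as.DataFrame(data.frame(clicked = c(0, 1, 0, 1),
                                   someString = c("this", "that", "this", "that"),
                                   stringsAsFactors = FALSE))
testdf <- as.DataFrame(data.frame(clicked = 0, someString = "the other",
                                  stringsAsFactors = FALSE))

# With the default handleInvalid = "error", collecting predictions would fail;
# "skip" filters out the row with the unseen feature value instead.
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                            handleInvalid = "skip")
head(predict(model, testdf))
```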

## How was this patch tested?

Add a new unit test based on the error case in the JIRA.

Author: wangmiao1981 <wm624@hotmail.com>

Closes #18496 from wangmiao1981/handle.
parent d0bfc673
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #' can speed up training of deeper trees. Users can set how often should the
 #' cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification models.
+#'        Supported options: "skip" (filter out rows with invalid data),
+#'        "error" (throw an error), "keep" (put invalid data in a special additional
+#'        bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                     maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                     featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                     minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                    maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                    maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                    handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
               new("RandomForestRegressionModel", jobj = jobj)
             },
             classification = {
+              handleInvalid <- match.arg(handleInvalid)
               if (is.null(impurity)) impurity <- "gini"
               impurity <- match.arg(impurity, c("gini", "entropy"))
               jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                   as.numeric(minInfoGain), as.integer(checkpointInterval),
                                   as.character(featureSubsetStrategy), seed,
                                   as.numeric(subsamplingRate),
-                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                  handleInvalid)
               new("RandomForestClassificationModel", jobj = jobj)
             }
           )
......
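As the signature above shows, handleInvalid defaults to the first element of c("error", "keep", "skip") via match.arg, so omitting it preserves the old fail-fast behavior; note that match.arg runs only in the classification branch, so the parameter has no effect for regression. A hedged sketch of how the three options are intended to differ at prediction time, per the documentation above (assumes traindf/testdf as in the unit test below, where testdf contains the unseen feature value "the other"):

```r
# Helper assumed for brevity; traindf and testdf come from the test case below.
fit <- function(h) spark.randomForest(traindf, clicked ~ ., type = "classification",
                                      handleInvalid = h)

collect(predict(fit("skip"), testdf))        # row with "the other" is filtered out
collect(predict(fit("keep"), testdf))        # "the other" kept in the extra bucket at index numLabels
try(collect(predict(fit("error"), testdf)))  # throws, matching the old behavior
```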
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10,
+                              handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
......
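One detail the test relies on: predict is lazy in SparkR, so with the default handleInvalid = "error" the failure only surfaces when the result is materialized, which is why the test wraps collect(predictions) rather than predict itself in expect_error. The same pattern outside testthat, as a minimal sketch (model and testdf assumed as above):

```r
predictions <- predict(model, testdf)   # no error yet: the transformation is lazy
result <- tryCatch(
  collect(predictions),                 # error surfaces on materialization
  error = function(e) conditionMessage(e)
)
```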
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
             .setInputCol(term)
             .setOutputCol(indexCol)
             .setStringOrderType($(stringIndexerOrderType))
+            .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
......
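Because RFormula forwards the same handleInvalid to every StringIndexer it creates (one per string term, as the hunk above shows), and its validator reuses StringIndexer.supportedHandleInvalids, the setting governs all string feature columns uniformly. A hedged SparkR sketch with two string features (column names are hypothetical):

```r
# Both browser and country are indexed by RFormula under the hood,
# and both inherit handleInvalid = "skip" from the single parameter.
df <- as.DataFrame(data.frame(clicked = c(0, 1, 0, 1),
                              browser = c("firefox", "chrome", "firefox", "chrome"),
                              country = c("us", "de", "us", "de"),
                              stringsAsFactors = FALSE))
model <- spark.randomForest(df, clicked ~ browser + country,
                            type = "classification", handleInvalid = "skip")
```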
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
......
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
......