From a7b46c627b5d2461257f337139a29f23350e0c77 Mon Sep 17 00:00:00 2001
From: wangmiao1981 <wm624@hotmail.com>
Date: Fri, 7 Jul 2017 23:51:32 -0700
Subject: [PATCH] [SPARK-20307][SPARKR] SparkR: pass on setHandleInvalid to
 spark.mllib functions that use StringIndexer

## What changes were proposed in this pull request?

For the randomForest classifier, prediction throws an error when the test data contains labels unseen during training. StringIndexer already has handleInvalid logic for exactly this case; this patch adds a new method that passes the chosen handleInvalid mode through to the underlying StringIndexer.

The same change applies to other classifiers as well. This PR covers the main logic and the randomForest classifier; follow-up PRs will cover the remaining classifiers.

## How was this patch tested?

Added a new unit test based on the error case reported in the JIRA.

Author: wangmiao1981 <wm624@hotmail.com>

Closes #18496 from wangmiao1981/handle.
---
 R/pkg/R/mllib_tree.R                          | 11 ++++++--
 R/pkg/tests/fulltests/test_mllib_tree.R       | 17 +++++++++++++
 .../apache/spark/ml/feature/RFormula.scala    | 25 +++++++++++++++++++
 .../r/RandomForestClassificationWrapper.scala |  4 ++-
 .../spark/ml/feature/StringIndexerSuite.scala |  2 +-
 5 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a752..75b1a74ee8 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                        nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                        can speed up training of deeper trees. Users can set how often should the
 #'                        cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model.
+#'                      Supported options: "skip" (filter out rows with invalid data),
+#'                      "error" (throw an error), "keep" (put invalid data in a special additional
+#'                      bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
               new("RandomForestRegressionModel", jobj = jobj)
             },
             classification = {
+              handleInvalid <- match.arg(handleInvalid)
               if (is.null(impurity)) impurity <- "gini"
               impurity <- match.arg(impurity, c("gini", "entropy"))
               jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                   as.numeric(minInfoGain), as.integer(checkpointInterval),
                                   as.character(featureSubsetStrategy), seed,
                                   as.numeric(subsamplingRate),
-                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                  as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                  handleInvalid)
               new("RandomForestClassificationModel", jobj = jobj)
             }
           )
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d270..66a0693a59 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10,
+                              handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 4b44878784..61aa6463bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
           .setInputCol(term)
           .setOutputCol(indexCol)
           .setStringOrderType($(stringIndexerOrderType))
+          .setHandleInvalid($(handleInvalid))
         prefixesToRewrite(indexCol + "_") = term + "_"
         (term, indexCol)
       case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 8a83d4e980..132345fb9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 806a92760c..027b1fbc66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
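For anyone trying the new option by hand, below is a minimal SparkR sketch of the before/after behavior, adapted directly from the unit test in this patch. It assumes a SparkR session against a build that includes this change; the toy columns (`clicked`, `someString`) and the unseen level `"the other"` come from the test and are illustrative only.

```r
library(SparkR)
sparkR.session()  # assumes Spark was built with this patch applied

# Toy data: a binary label and one string feature.
data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
                   someString = base::sample(c("this", "that"), 10, replace = TRUE),
                   stringsAsFactors = FALSE)
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
traindf <- as.DataFrame(data[trainidxs, ])
# The test set gains a row whose feature value "the other" never occurs in training.
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))

# Default handleInvalid = "error": predictions on the unseen level fail at collect time.
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                            maxDepth = 10, maxBins = 10, numTrees = 10)
# collect(predict(model, testdf))  # throws an error because of "the other"

# handleInvalid = "skip": rows with unseen levels are filtered out before scoring.
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
                            maxDepth = 10, maxBins = 10, numTrees = 10,
                            handleInvalid = "skip")
head(collect(predict(model, testdf)))
```

Per the StringIndexer documentation, `handleInvalid = "keep"` would instead route unseen levels into an extra bucket at index numLabels, so no test rows are dropped.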