diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a752783652085b90f61985584ff2e65ce8..75b1a74ee8c7cfb6b746f29158d61c3ca8551ecc 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                can speed up training of deeper trees. Users can set how often should the
 #'                cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in a
+#'        classification model. Supported options: "skip" (filter out rows with invalid data),
+#'        "error" (throw an error), "keep" (put invalid data in a special additional
+#'        bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                      maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                      featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                      minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                     maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                     maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                     handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                        new("RandomForestRegressionModel", jobj = jobj)
                      },
                      classification = {
+                       handleInvalid <- match.arg(handleInvalid)
                        if (is.null(impurity)) impurity <- "gini"
                        impurity <- match.arg(impurity, c("gini", "entropy"))
                        jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                            as.numeric(minInfoGain), as.integer(checkpointInterval),
                                            as.character(featureSubsetStrategy), seed,
                                            as.numeric(subsamplingRate),
-                                           as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                           as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                           handleInvalid)
                        new("RandomForestClassificationModel", jobj = jobj)
                      }
           )
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d270b256ca279ce36314f8580c0067acce..66a0693a59a529765b6d156f9179bb937631a8f9 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
+  expect_error(collect(predictions))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                              maxDepth = 10, maxBins = 10, numTrees = 10,
+                              handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 4b44878784c9055494db57c0acddce74ebc641c2..61aa6463bb6daec9b5da801058aaef0d9194c127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
             .setInputCol(term)
             .setOutputCol(indexCol)
             .setStringOrderType($(stringIndexerOrderType))
+            .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 8a83d4e980f7b496b668f177549623ae6c6a0ba0..132345fb9a6d9ae74826de58a7479f51aeb593c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 806a92760c8b6b2ee5fc65d29adef183d0579149..027b1fbc6657cd2b918086d15a649187b0aa2f56 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
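For readers following the change, here is a minimal sketch of how the new argument behaves from SparkR, in the spirit of the test above. It assumes a running SparkR session; the toy train/test frames and their column names are illustrative only, not part of the patch. Note that the default "error" preserves the previous behavior, where collecting predictions over unseen levels fails.

  # Toy data: "someString" has levels "this"/"that" at training time, and the
  # test frame introduces the unseen level "the other".
  train <- as.DataFrame(data.frame(clicked = c(0, 1, 0, 1),
                                   someString = c("this", "that", "this", "that"),
                                   stringsAsFactors = FALSE))
  test <- as.DataFrame(data.frame(clicked = 0, someString = "the other",
                                  stringsAsFactors = FALSE))

  # Default is handleInvalid = "error": predict() is lazy, so the failure
  # surfaces when the predictions are collected, as in the test above.
  model <- spark.randomForest(train, clicked ~ ., type = "classification")
  # collect(predict(model, test))  # throws: "the other" was never indexed

  # "skip" filters out rows carrying invalid values instead of failing, so the
  # single invalid test row simply disappears from the result.
  model <- spark.randomForest(train, clicked ~ ., type = "classification",
                              handleInvalid = "skip")
  collect(predict(model, test))  # zero rows: the invalid row was dropped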
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col
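Continuing the sketch above (same hypothetical train/test frames), the third option, "keep", neither fails nor drops the row: per the param documentation, the unseen level is put into one extra bucket at index numLabels, so the row should still flow through and receive a prediction.

  # "keep" routes unseen levels into an additional bucket instead of dropping
  # them, so the "the other" row survives and is scored.
  model <- spark.randomForest(train, clicked ~ ., type = "classification",
                              handleInvalid = "keep")
  head(predict(model, test))  # one row back, prediction included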