diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R
index 2f1220a752783652085b90f61985584ff2e65ce8..75b1a74ee8c7cfb6b746f29158d61c3ca8551ecc 100644
--- a/R/pkg/R/mllib_tree.R
+++ b/R/pkg/R/mllib_tree.R
@@ -374,6 +374,10 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara
 #'                     nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching
 #'                     can speed up training of deeper trees. Users can set how often should the
 #'                     cache be checkpointed or disable it by setting checkpointInterval.
+#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in the
+#'                      classification model. Supported options: "skip" (filter out rows with
+#'                      invalid data), "error" (throw an error), and "keep" (put invalid data
+#'                      in a special additional bucket, at index numLabels). Default is "error".
 #' @param ... additional arguments passed to the method.
 #' @aliases spark.randomForest,SparkDataFrame,formula-method
 #' @return \code{spark.randomForest} returns a fitted Random Forest model.
@@ -409,7 +413,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                    maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL,
                    featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0,
                    minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10,
-                   maxMemoryInMB = 256, cacheNodeIds = FALSE) {
+                   maxMemoryInMB = 256, cacheNodeIds = FALSE,
+                   handleInvalid = c("error", "keep", "skip")) {
             type <- match.arg(type)
             formula <- paste(deparse(formula), collapse = "")
             if (!is.null(seed)) {
@@ -430,6 +435,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                      new("RandomForestRegressionModel", jobj = jobj)
                    },
                    classification = {
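+                     # match.arg() validates the choice and defaults to "error" (the first element)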
+                     handleInvalid <- match.arg(handleInvalid)
                      if (is.null(impurity)) impurity <- "gini"
                      impurity <- match.arg(impurity, c("gini", "entropy"))
                      jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper",
@@ -439,7 +445,8 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo
                                          as.numeric(minInfoGain), as.integer(checkpointInterval),
                                          as.character(featureSubsetStrategy), seed,
                                          as.numeric(subsamplingRate),
-                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds))
+                                         as.integer(maxMemoryInMB), as.logical(cacheNodeIds),
+                                         handleInvalid)
                      new("RandomForestClassificationModel", jobj = jobj)
                    }
             )
diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R
index 9b3fc8d270b256ca279ce36314f8580c0067acce..66a0693a59a529765b6d156f9179bb937631a8f9 100644
--- a/R/pkg/tests/fulltests/test_mllib_tree.R
+++ b/R/pkg/tests/fulltests/test_mllib_tree.R
@@ -212,6 +212,23 @@ test_that("spark.randomForest", {
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
 
+  # Test unseen labels
+  data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE),
+                     someString = base::sample(c("this", "that"), 10, replace = TRUE),
+                     stringsAsFactors = FALSE)
+  trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
+  traindf <- as.DataFrame(data[trainidxs, ])
+  testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                             maxDepth = 10, maxBins = 10, numTrees = 10)
+  predictions <- predict(model, testdf)
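+  # default handleInvalid = "error": collecting predictions for the unseen label fails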
+  expect_error(collect(predictions))
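+  # with handleInvalid = "skip", rows containing unseen labels are filtered out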
+  model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
+                             maxDepth = 10, maxBins = 10, numTrees = 10,
+                             handleInvalid = "skip")
+  predictions <- predict(model, testdf)
+  expect_equal(class(collect(predictions)$clicked[1]), "character")
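+  # "keep", the third option, would instead map the unseen label to an extra
+  # bucket at index numLabels rather than dropping the row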
+
   # spark.randomForest classification can work on libsvm data
   if (windows_with_hadoop()) {
     data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 4b44878784c9055494db57c0acddce74ebc641c2..61aa6463bb6daec9b5da801058aaef0d9194c127 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -132,6 +132,30 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   @Since("1.5.0")
   def getFormula: String = $(formula)
 
+  /**
+   * Param for how to handle invalid data (unseen labels or NULL values).
+   * Options are 'skip' (filter out rows with invalid data),
+   * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+   * bucket, at index numLabels).
+   * Default: "error"
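+   *
+   * A minimal usage sketch (the formula and column name are illustrative):
+   * {{{
+   *   new RFormula().setFormula("clicked ~ .").setHandleInvalid("skip")
+   * }}}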
+   * @group param
+   */
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle " +
+    "invalid data (unseen labels or NULL values). " +
+    "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
+    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+  setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
   /** @group setParam */
   @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -197,6 +221,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
             .setInputCol(term)
             .setOutputCol(indexCol)
             .setStringOrderType($(stringIndexerOrderType))
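+            // propagate this RFormula's handleInvalid to each per-column StringIndexer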
+            .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
         case _ =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
index 8a83d4e980f7b496b668f177549623ae6c6a0ba0..132345fb9a6d9ae74826de58a7479f51aeb593c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala
@@ -78,11 +78,13 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC
       seed: String,
       subsamplingRate: Double,
       maxMemoryInMB: Int,
-      cacheNodeIds: Boolean): RandomForestClassifierWrapper = {
+      cacheNodeIds: Boolean,
+      handleInvalid: String): RandomForestClassifierWrapper = {
 
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
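+      // forward the option chosen on the R side; see RFormula.handleInvalid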
+      .setHandleInvalid(handleInvalid)
     checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 806a92760c8b6b2ee5fc65d29adef183d0579149..027b1fbc6657cd2b918086d15a649187b0aa2f56 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions.col