diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 97c86552986096f61c2549b15fa2d190e23139c4..af007625d1827cc5c3c9da517de643a2c51d69c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
       n.toInt
     }
   }
+
+  /**
+   * Param for strategy for dealing with unknown or new users/items at prediction time.
+   * This may be useful in cross-validation or production scenarios, for handling user/item ids
+   * the model has not seen in the training data.
+   * Supported values:
+   * - "nan":  predicted value for unknown ids will be NaN.
+   * - "drop": rows in the input DataFrame containing unknown ids will be dropped from
+   *           the output DataFrame containing predictions.
+   * Default: "nan".
+   * @group expertParam
+   */
+  val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+    "strategy for dealing with unknown or new users/items at prediction time. This may be " +
+    "useful in cross-validation or production scenarios, for handling user/item ids the model " +
+    "has not seen in the training data. Supported values: " +
+    s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
+    (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))
+
+  /** @group expertGetParam */
+  def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
 }
 
 /**
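
The validator wired into the new param above is simply a case-insensitive membership check against the supported list, which is what lets the getter lower-case the stored value safely. A standalone sketch of that predicate (the strategy list is copied here rather than referenced, so the snippet runs without Spark on the classpath):

    // Sketch of the coldStartStrategy validator's behaviour in isolation;
    // the supported values mirror ALSModel.supportedColdStartStrategies below.
    object ColdStartValidatorSketch {
      private val supported = Array("nan", "drop")
      val isValid: String => Boolean = s => supported.contains(s.toLowerCase)

      def main(args: Array[String]): Unit = {
        assert(isValid("Drop"))    // any casing of a supported strategy passes
        assert(isValid("NAN"))
        assert(!isValid("skip"))   // anything else fails validation when the param is set
      }
    }
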
@@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
     ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+    coldStartStrategy -> "nan")
 
   /**
    * Validates and transforms the input schema.
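
With coldStartStrategy added to setDefault, a freshly constructed estimator reports "nan", and since getColdStartStrategy lower-cases whatever was stored, mixed-case settings normalise on read. A param-only sketch using the estimator setter added further down (no SparkSession or data needed, e.g. pasted into spark-shell):

    import org.apache.spark.ml.recommendation.ALS

    val als = new ALS()
    assert(als.getColdStartStrategy == "nan")      // default introduced above

    als.setColdStartStrategy("DROP")
    assert(als.getColdStartStrategy == "drop")     // getter lower-cases the stored value

    // An unsupported value should be rejected by the param validator at set time:
    // als.setColdStartStrategy("skip")            // expected to throw IllegalArgumentException
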
@@ -248,6 +270,10 @@ class ALSModel private[ml] (
   @Since("1.3.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
@@ -260,13 +286,19 @@ class ALSModel private[ml] (
         Float.NaN
       }
     }
-    dataset
+    val predictions = dataset
       .join(userFactors,
         checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
       .join(itemFactors,
         checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+    getColdStartStrategy match {
+      case ALSModel.Drop =>
+        predictions.na.drop("all", Seq($(predictionCol)))
+      case ALSModel.NaN =>
+        predictions
+    }
   }
 
   @Since("1.3.0")
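
The Drop branch above delegates to DataFrameNaFunctions.drop, which removes rows whose listed columns are null or NaN, so restricting it to the prediction column filters exactly the cold-start rows. A minimal spark-shell sketch of that one call, with a toy frame standing in for the joined predictions (the values are invented for illustration):

    import spark.implicits._

    // NaN in the prediction column stands in for an unknown user/item pairing.
    val predictions = Seq(
      (0, 1, Float.NaN),
      (1, 1, 3.2f)
    ).toDF("user", "item", "prediction")

    // Same call as the Drop case above: keep only rows with a real prediction.
    val kept = predictions.na.drop("all", Seq("prediction"))
    assert(kept.count() == 1)
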
@@ -290,6 +322,10 @@ class ALSModel private[ml] (
 @Since("1.6.0")
 object ALSModel extends MLReadable[ALSModel] {
 
+  private val NaN = "nan"
+  private val Drop = "drop"
+  private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)
+
   @Since("1.6.0")
   override def read: MLReader[ALSModel] = new ALSModelReader
 
@@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
   @Since("2.0.0")
   def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    *
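
Taken together, the new setter means a held-out split containing unseen users or items no longer yields NaN predictions that poison an evaluation metric. A hedged end-to-end sketch of that cross-validation-style workflow in spark-shell; the `ratings` DataFrame, its column names and the 80/20 split are assumptions for illustration:

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.recommendation.ALS

    // `ratings` is assumed to hold user/item/rating columns.
    val Array(train, test) = ratings.randomSplit(Array(0.8, 0.2), seed = 42L)

    val model = new ALS()
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")
      .setColdStartStrategy("drop")   // unseen ids are filtered instead of producing NaN
      .fit(train)

    // RMSE stays finite even if `test` contains users/items absent from `train`.
    val rmse = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
      .evaluate(model.transform(test))
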
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index b923bacce23ca813defbfc5ad105d29ab1f8b5d5..c9e7b505b2bd2a69f4efee3f67e65c608d452d57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -498,8 +498,8 @@ class ALSSuite
           (ex, act) =>
             ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
         } { (ex, act, _) =>
-          ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
-            act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+          ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
+            act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
         }
     }
     // check user/item ids falling outside of Int range
@@ -547,6 +547,53 @@ class ALSSuite
       ALS.train(ratings)
     }
   }
+
+  test("ALS cold start user/item prediction strategy") {
+    val spark = this.spark
+    import spark.implicits._
+    import org.apache.spark.sql.functions._
+
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+    val data = ratings.toDF
+    val knownUser = data.select(max("user")).as[Int].first()
+    val unknownUser = knownUser + 10
+    val knownItem = data.select(max("item")).as[Int].first()
+    val unknownItem = knownItem + 20
+    val test = Seq(
+      (unknownUser, unknownItem),
+      (knownUser, unknownItem),
+      (unknownUser, knownItem),
+      (knownUser, knownItem)
+    ).toDF("user", "item")
+
+    val als = new ALS().setMaxIter(1).setRank(1)
+    // default is 'nan'
+    val defaultModel = als.fit(data)
+    val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
+    assert(defaultPredictions.length == 4)
+    assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
+    assert(!defaultPredictions.last.isNaN)
+
+    // check that the 'drop' strategy filters out rows with unknown users/items
+    val dropPredictions = defaultModel
+      .setColdStartStrategy("drop")
+      .transform(test)
+      .select("prediction").as[Float].collect()
+    assert(dropPredictions.length == 1)
+    assert(!dropPredictions.head.isNaN)
+    assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+  }
+
+  test("case insensitive cold start param value") {
+    val spark = this.spark
+    import spark.implicits._
+    val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1)
+    val data = ratings.toDF
+    val model = new ALS().fit(data)
+    Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
+      model.setColdStartStrategy(s).transform(data)
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
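
A quick way to reconcile the two strategies the new tests exercise is to count NaN predictions under "nan" and compare with the row count under "drop"; together they should account for every input row. A rough spark-shell sketch, assuming a fitted `model` and a held-out `test` DataFrame like those in the test above:

    import org.apache.spark.sql.functions.{col, isnan}

    // Cold-start rows survive as NaN predictions under the default strategy...
    val nanCount = model.setColdStartStrategy("nan").transform(test)
      .filter(isnan(col("prediction"))).count()

    // ...and are removed under "drop", so the totals should reconcile.
    val dropCount = model.setColdStartStrategy("drop").transform(test).count()
    assert(nanCount + dropCount == test.count())
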
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e28d38bd19f80b12f0cc66780f548a235f2f1f12..43f82daa9fcd3adea77792aabf6e50712686a118 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
                               "StorageLevel for ALS model factors.",
                               typeConverter=TypeConverters.toString)
+    coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
+                              "unknown or new users/items at prediction time. This may be useful " +
+                              "in cross-validation or production scenarios, for handling " +
+                              "user/item ids the model has not seen in the training data. " +
+                              "Supported values: 'nan', 'drop'.",
+                              typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10,
                  intermediateStorageLevel="MEMORY_AND_DISK",
-                 finalStorageLevel="MEMORY_AND_DISK"):
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
         """
         __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                  implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=false, checkpointInterval=10, \
                  intermediateStorageLevel="MEMORY_AND_DISK", \
-                 finalStorageLevel="MEMORY_AND_DISK")
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         """
         super(ALS, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
@@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                          implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
                          ratingCol="rating", nonnegative=False, checkpointInterval=10,
                          intermediateStorageLevel="MEMORY_AND_DISK",
-                         finalStorageLevel="MEMORY_AND_DISK")
+                         finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                   implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                   ratingCol="rating", nonnegative=False, checkpointInterval=10,
                   intermediateStorageLevel="MEMORY_AND_DISK",
-                  finalStorageLevel="MEMORY_AND_DISK"):
+                  finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
         """
         setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=False, checkpointInterval=10, \
                  intermediateStorageLevel="MEMORY_AND_DISK", \
-                 finalStorageLevel="MEMORY_AND_DISK")
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         Sets params for ALS.
         """
         kwargs = self.setParams._input_kwargs
@@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
         """
         return self.getOrDefault(self.finalStorageLevel)
 
+    @since("2.2.0")
+    def setColdStartStrategy(self, value):
+        """
+        Sets the value of :py:attr:`coldStartStrategy`.
+        """
+        return self._set(coldStartStrategy=value)
+
+    @since("2.2.0")
+    def getColdStartStrategy(self):
+        """
+        Gets the value of coldStartStrategy or its default value.
+        """
+        return self.getOrDefault(self.coldStartStrategy)
+
 
 class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """