From 0ba9ecbea88533b2562f2f6045eafeab99d8f0c6 Mon Sep 17 00:00:00 2001
From: Bryan Cutler <cutlerb@gmail.com>
Date: Tue, 7 Mar 2017 20:44:30 -0800
Subject: [PATCH] [SPARK-19348][PYTHON] PySpark keyword_only decorator is not
 thread-safe

## What changes were proposed in this pull request?
The `keyword_only` decorator in PySpark is not thread-safe.  It writes kwargs to a static class variable in the decorator, which is then retrieved later in the class method as `_input_kwargs`.  If multiple threads are constructing the same class with different kwargs, it becomes a race condition to read from the static class variable before it's overwritten.  See [SPARK-19348](https://issues.apache.org/jira/browse/SPARK-19348) for reproduction code.

This change will write the kwargs to a member variable so that multiple threads can operate on separate instances without the race condition.  It does not protect against multiple threads operating on a single instance, but that is better left to the user to synchronize.

## How was this patch tested?
Added new unit tests for using the keyword_only decorator and a regression test that verifies `_input_kwargs` can be overwritten from different class instances.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #17193 from BryanCutler/pyspark-keyword_only-threadsafe-SPARK-19348-2_1.
---
 python/pyspark/__init__.py          |  10 ++-
 python/pyspark/ml/classification.py |  28 +++----
 python/pyspark/ml/clustering.py     |  16 ++--
 python/pyspark/ml/evaluation.py     |  12 +--
 python/pyspark/ml/feature.py        | 112 ++++++++++++++--------------
 python/pyspark/ml/pipeline.py       |   4 +-
 python/pyspark/ml/recommendation.py |   4 +-
 python/pyspark/ml/regression.py     |  28 +++----
 python/pyspark/ml/tests.py          |   8 +-
 python/pyspark/ml/tuning.py         |   8 +-
 python/pyspark/tests.py             |  39 ++++++++++
 11 files changed, 155 insertions(+), 114 deletions(-)

diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 5f93586a48..f7927b38e5 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -90,13 +90,15 @@ def keyword_only(func):
     """
     A decorator that forces keyword arguments in the wrapped method
     and saves actual input keyword arguments in `_input_kwargs`.
+
+    .. note:: Should only be used to wrap a method where first arg is `self`
     """
     @wraps(func)
-    def wrapper(*args, **kwargs):
-        if len(args) > 1:
+    def wrapper(self, *args, **kwargs):
+        if len(args) > 0:
             raise TypeError("Method %s forces keyword arguments." % func.__name__)
-        wrapper._input_kwargs = kwargs
-        return func(*args, **kwargs)
+        self._input_kwargs = kwargs
+        return func(self, **kwargs)
     return wrapper
 
 
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5fe4bab186..570a414cc3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -152,7 +152,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.LogisticRegression", self.uid)
         self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
         self._checkThresholdConsistency()
 
@@ -172,7 +172,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         Sets params for logistic regression.
         If the threshold and thresholds Params are both set, they must be equivalent.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
         self._checkThresholdConsistency()
         return self
@@ -646,7 +646,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
         self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          impurity="gini")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -664,7 +664,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
                   seed=None)
         Sets params for the DecisionTreeClassifier.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -776,7 +776,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          impurity="gini", numTrees=20, featureSubsetStrategy="auto",
                          subsamplingRate=1.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -794,7 +794,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0)
         Sets params for linear classification.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -917,7 +917,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
         self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -933,7 +933,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
                   lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0)
         Sets params for Gradient Boosted Tree Classification.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1060,7 +1060,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.NaiveBayes", self.uid)
         self._setDefault(smoothing=1.0, modelType="multinomial")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1074,7 +1074,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
                   modelType="multinomial", thresholds=None, weightCol=None)
         Sets params for Naive Bayes.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1215,7 +1215,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
         self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1229,7 +1229,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
                   solver="l-bfgs", initialWeights=None)
         Sets params for MultilayerPerceptronClassifier.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1400,7 +1400,7 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
                  classifier=None)
         """
         super(OneVsRest, self).__init__()
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
@@ -1410,7 +1410,7 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
         setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None):
         Sets params for OneVsRest.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _fit(self, dataset):
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 35d0aefa04..86aa28905c 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -232,7 +232,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
         self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture",
                                             self.uid)
         self._setDefault(k=2, tol=0.01, maxIter=100)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     def _create_model(self, java_model):
@@ -248,7 +248,7 @@ class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
 
         Sets params for GaussianMixture.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
@@ -414,7 +414,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
         super(KMeans, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid)
         self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     def _create_model(self, java_model):
@@ -430,7 +430,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
 
         Sets params for KMeans.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.5.0")
@@ -591,7 +591,7 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
         self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans",
                                             self.uid)
         self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -603,7 +603,7 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte
                   seed=None, k=4, minDivisibleClusterSize=1.0)
         Sets params for BisectingKMeans.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
@@ -916,7 +916,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter
                          k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51,
                          subsamplingRate=0.05, optimizeDocConcentration=True,
                          topicDistributionCol="topicDistribution", keepLastCheckpoint=True)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     def _create_model(self, java_model):
@@ -941,7 +941,7 @@ class LDA(JavaEstimator, HasFeaturesCol, HasMaxIter, HasSeed, HasCheckpointInter
 
         Sets params for LDA.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 7aa16fa5b9..7cb8d62f21 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -148,7 +148,7 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
             "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid)
         self._setDefault(rawPredictionCol="rawPrediction", labelCol="label",
                          metricName="areaUnderROC")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @since("1.4.0")
@@ -174,7 +174,7 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
                   metricName="areaUnderROC")
         Sets params for binary classification evaluator.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
@@ -226,7 +226,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
             "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
         self._setDefault(predictionCol="prediction", labelCol="label",
                          metricName="rmse")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @since("1.4.0")
@@ -252,7 +252,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
                   metricName="rmse")
         Sets params for regression evaluator.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
@@ -299,7 +299,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
             "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
         self._setDefault(predictionCol="prediction", labelCol="label",
                          metricName="f1")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @since("1.5.0")
@@ -325,7 +325,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
                   metricName="f1")
         Sets params for multiclass classification evaluator.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 62c31431b5..3a4b6ed6a3 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -92,7 +92,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
         super(Binarizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
         self._setDefault(threshold=0.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -102,7 +102,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
         setParams(self, threshold=0.0, inputCol=None, outputCol=None)
         Sets params for this Binarizer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -178,7 +178,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
         self._setDefault(handleInvalid="error")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -188,7 +188,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this Bucketizer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -292,7 +292,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                             self.uid)
         self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -304,7 +304,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
                   outputCol=None)
         Set the params for the CountVectorizer
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -424,7 +424,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
         super(DCT, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid)
         self._setDefault(inverse=False)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -434,7 +434,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
         setParams(self, inverse=False, inputCol=None, outputCol=None)
         Sets params for this DCT.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -488,7 +488,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
         super(ElementwiseProduct, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct",
                                             self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -498,7 +498,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
         setParams(self, scalingVec=None, inputCol=None, outputCol=None)
         Sets params for this ElementwiseProduct.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
@@ -558,7 +558,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
         super(HashingTF, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
         self._setDefault(numFeatures=1 << 18, binary=False)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -568,7 +568,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
         setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None)
         Sets params for this HashingTF.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
@@ -631,7 +631,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
         super(IDF, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
         self._setDefault(minDocFreq=0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -641,7 +641,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
         setParams(self, minDocFreq=0, inputCol=None, outputCol=None)
         Sets params for this IDF.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -721,7 +721,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         super(MaxAbsScaler, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid)
         self._setDefault()
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -731,7 +731,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         setParams(self, inputCol=None, outputCol=None)
         Sets params for this MaxAbsScaler.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -815,7 +815,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         super(MinMaxScaler, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid)
         self._setDefault(min=0.0, max=1.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -825,7 +825,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None)
         Sets params for this MinMaxScaler.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -933,7 +933,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
         super(NGram, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid)
         self._setDefault(n=2)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -943,7 +943,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
         setParams(self, n=2, inputCol=None, outputCol=None)
         Sets params for this NGram.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.5.0")
@@ -997,7 +997,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         super(Normalizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
         self._setDefault(p=2.0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1007,7 +1007,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
         setParams(self, p=2.0, inputCol=None, outputCol=None)
         Sets params for this Normalizer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -1077,7 +1077,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         super(OneHotEncoder, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
         self._setDefault(dropLast=True)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1087,7 +1087,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         setParams(self, dropLast=True, inputCol=None, outputCol=None)
         Sets params for this OneHotEncoder.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -1143,7 +1143,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
         self._setDefault(degree=2)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1153,7 +1153,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
         setParams(self, degree=2, inputCol=None, outputCol=None)
         Sets params for this PolynomialExpansion.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -1239,7 +1239,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
                                             self.uid)
         self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1251,7 +1251,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
                   handleInvalid="error")
         Set the params for the QuantileDiscretizer
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
@@ -1364,7 +1364,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         super(RegexTokenizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
         self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1376,7 +1376,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
                   outputCol=None, toLowercase=True)
         Sets params for this RegexTokenizer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -1467,7 +1467,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
         """
         super(SQLTransformer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1477,7 +1477,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
         setParams(self, statement=None)
         Sets params for this SQLTransformer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -1546,7 +1546,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J
         super(StandardScaler, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
         self._setDefault(withMean=False, withStd=True)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1556,7 +1556,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J
         setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
         Sets params for this StandardScaler.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -1662,7 +1662,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
         self._setDefault(handleInvalid="error")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1672,7 +1672,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
         setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this StringIndexer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1720,7 +1720,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         super(IndexToString, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString",
                                             self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1730,7 +1730,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
         setParams(self, inputCol=None, outputCol=None, labels=None)
         Sets params for this IndexToString.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -1784,7 +1784,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
                                             self.uid)
         self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"),
                          caseSensitive=False)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1794,7 +1794,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
         setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false)
         Sets params for this StopWordRemover.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -1877,7 +1877,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
         """
         super(Tokenizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1887,7 +1887,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
         setParams(self, inputCol=None, outputCol=None)
         Sets params for this Tokenizer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
@@ -1921,7 +1921,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
         """
         super(VectorAssembler, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1931,7 +1931,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
         setParams(self, inputCols=None, outputCol=None)
         Sets params for this VectorAssembler.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
@@ -2019,7 +2019,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
         self._setDefault(maxCategories=20)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2029,7 +2029,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
         setParams(self, maxCategories=20, inputCol=None, outputCol=None)
         Sets params for this VectorIndexer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -2134,7 +2134,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
         super(VectorSlicer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid)
         self._setDefault(indices=[], names=[])
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2144,7 +2144,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
         setParams(self, inputCol=None, outputCol=None, indices=None, names=None):
         Sets params for this VectorSlicer.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.6.0")
@@ -2257,7 +2257,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
         self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
                          windowSize=5, maxSentenceLength=1000)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2269,7 +2269,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
                  inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000)
         Sets params for this Word2Vec.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -2417,7 +2417,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
         """
         super(PCA, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2427,7 +2427,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
         setParams(self, k=None, inputCol=None, outputCol=None)
         Set params for this PCA.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.5.0")
@@ -2557,7 +2557,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
         super(RFormula, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
         self._setDefault(forceIndexLabel=False)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2569,7 +2569,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
                   forceIndexLabel=False)
         Sets params for RFormula.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.5.0")
@@ -2687,7 +2687,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
         self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1,
                          fpr=0.05)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -2699,7 +2699,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
                   labelCol="labels", selectorType="numTopFeatures", percentile=0.1, fpr=0.05)
         Sets params for this ChiSqSelector.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.1.0")
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 4307ad02a0..2d2e4c13e8 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -58,7 +58,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         __init__(self, stages=None)
         """
         super(Pipeline, self).__init__()
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @since("1.3.0")
@@ -85,7 +85,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         setParams(self, stages=None)
         Sets params for Pipeline.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _fit(self, dataset):
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e28d38bd19..ee9916f472 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -146,7 +146,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                          ratingCol="rating", nonnegative=False, checkpointInterval=10,
                          intermediateStorageLevel="MEMORY_AND_DISK",
                          finalStorageLevel="MEMORY_AND_DISK")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -164,7 +164,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                  finalStorageLevel="MEMORY_AND_DISK")
         Sets params for ALS.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b42e807069..b199bf282e 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -108,7 +108,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.LinearRegression", self.uid)
         self._setDefault(maxIter=100, regParam=0.0, tol=1e-6)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -122,7 +122,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
                   standardization=True, solver="auto", weightCol=None, aggregationDepth=2)
         Sets params for linear regression.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -464,7 +464,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.IsotonicRegression", self.uid)
         self._setDefault(isotonic=True, featureIndex=0)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -475,7 +475,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
                  weightCol=None, isotonic=True, featureIndex=0):
         Set the params for IsotonicRegression.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -704,7 +704,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          impurity="variance")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -720,7 +720,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
                   impurity="variance", seed=None, varianceCol=None)
         Sets params for the DecisionTreeRegressor.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -895,7 +895,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
                          maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
                          impurity="variance", subsamplingRate=1.0, numTrees=20,
                          featureSubsetStrategy="auto")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -913,7 +913,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
                   featureSubsetStrategy="auto")
         Sets params for linear regression.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1022,7 +1022,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                          maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
                          checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
                          impurity="variance")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1040,7 +1040,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                   impurity="variance")
         Sets params for Gradient Boosted Tree Regression.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1171,7 +1171,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         self._setDefault(censorCol="censor",
                          quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
                          maxIter=100, tol=1E-6)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1186,7 +1186,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
                   quantilesCol=None, aggregationDepth=2):
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
@@ -1366,7 +1366,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
         self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -1380,7 +1380,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha
                   regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None)
         Sets params for generalized linear regression.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     def _create_model(self, java_model):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 46be031ee8..70e0c6de4a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -250,7 +250,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed):
     def __init__(self, seed=None):
         super(TestParams, self).__init__()
         self._setDefault(maxIter=10)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -259,7 +259,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed):
         setParams(self, seed=None)
         Sets params for this test.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
@@ -271,7 +271,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
     def __init__(self, seed=None):
         super(OtherTestParams, self).__init__()
         self._setDefault(maxIter=10)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
@@ -280,7 +280,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
         setParams(self, seed=None)
         Sets params for this test.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
 
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 2dcc99cef8..ffeb4459e1 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -186,7 +186,7 @@ class CrossValidator(Estimator, ValidatorParams):
         """
         super(CrossValidator, self).__init__()
         self._setDefault(numFolds=3)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @keyword_only
@@ -198,7 +198,7 @@ class CrossValidator(Estimator, ValidatorParams):
                   seed=None):
         Sets params for cross validator.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("1.4.0")
@@ -346,7 +346,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
         """
         super(TrainValidationSplit, self).__init__()
         self._setDefault(trainRatio=0.75)
-        kwargs = self.__init__._input_kwargs
+        kwargs = self._input_kwargs
         self._set(**kwargs)
 
     @since("2.0.0")
@@ -358,7 +358,7 @@ class TrainValidationSplit(Estimator, ValidatorParams):
                   seed=None):
         Sets params for the train validation split.
         """
-        kwargs = self.setParams._input_kwargs
+        kwargs = self._input_kwargs
         return self._set(**kwargs)
 
     @since("2.0.0")
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 8e35a4ee8e..1df91ad956 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -58,6 +58,7 @@ else:
     from StringIO import StringIO
 
 
+from pyspark import keyword_only
 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
@@ -2095,6 +2096,44 @@ class ConfTests(unittest.TestCase):
             sc.stop()
 
 
+class KeywordOnlyTests(unittest.TestCase):
+    class Wrapped(object):
+        @keyword_only
+        def set(self, x=None, y=None):
+            if "x" in self._input_kwargs:
+                self._x = self._input_kwargs["x"]
+            if "y" in self._input_kwargs:
+                self._y = self._input_kwargs["y"]
+            return x, y
+
+    def test_keywords(self):
+        w = self.Wrapped()
+        x, y = w.set(y=1)
+        self.assertEqual(y, 1)
+        self.assertEqual(y, w._y)
+        self.assertIsNone(x)
+        self.assertFalse(hasattr(w, "_x"))
+
+    def test_non_keywords(self):
+        w = self.Wrapped()
+        self.assertRaises(TypeError, lambda: w.set(0, y=1))
+
+    def test_kwarg_ownership(self):
+        # test _input_kwargs is owned by each class instance and not a shared static variable
+        class Setter(object):
+            @keyword_only
+            def set(self, x=None, other=None, other_x=None):
+                if "other" in self._input_kwargs:
+                    self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
+                self._x = self._input_kwargs["x"]
+
+        a = Setter()
+        b = Setter()
+        a.set(x=1, other=b, other_x=2)
+        self.assertEqual(a._x, 1)
+        self.assertEqual(b._x, 2)
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
 
-- 
GitLab