diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 0c6a37bab0aadcafe758116ccf7bf331e9e0c7b8..9c131a41850cc827d8ff05fd73c1814d4c8698e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -27,7 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.feature.ChiSqSelectorType
+import org.apache.spark.mllib.feature.{ChiSqSelector => OldChiSqSelector}
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.rdd.RDD
@@ -44,7 +44,9 @@ private[feature] trait ChiSqSelectorParams extends Params
   /**
    * Number of features that selector will select (ordered by statistic value descending). If the
    * number of features is less than numTopFeatures, then this will select all features.
+   * Only applicable when selectorType = "kbest".
    * The default value of numTopFeatures is 50.
+   *
    * @group param
    */
   final val numTopFeatures = new IntParam(this, "numTopFeatures",
@@ -56,6 +58,11 @@ private[feature] trait ChiSqSelectorParams extends Params
   /** @group getParam */
   def getNumTopFeatures: Int = $(numTopFeatures)
 
+  /**
+   * Percentile of features that selector will select, ordered by statistics value descending.
+   * Only applicable when selectorType = "percentile".
+   * Default value is 0.1.
+   */
   final val percentile = new DoubleParam(this, "percentile",
     "Percentile of features that selector will select, ordered by statistics value descending.",
     ParamValidators.inRange(0, 1))
@@ -64,8 +71,12 @@ private[feature] trait ChiSqSelectorParams extends Params
   /** @group getParam */
   def getPercentile: Double = $(percentile)
 
-  final val alpha = new DoubleParam(this, "alpha",
-    "The highest p-value for features to be kept.",
+  /**
+   * The highest p-value for features to be kept.
+   * Only applicable when selectorType = "fpr".
+   * Default value is 0.05.
+   */
+  final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.",
     ParamValidators.inRange(0, 1))
   setDefault(alpha -> 0.05)
 
@@ -73,29 +84,27 @@ private[feature] trait ChiSqSelectorParams extends Params
   def getAlpha: Double = $(alpha)
 
   /**
-   * The ChiSqSelector supports KBest, Percentile, FPR selection,
-   * which is the same as ChiSqSelectorType defined in MLLIB.
-   * when call setNumTopFeatures, the selectorType is set to KBest
-   * when call setPercentile, the selectorType is set to Percentile
-   * when call setAlpha, the selectorType is set to FPR
+   * The selector type of the ChisqSelector.
+   * Supported options: "kbest" (default), "percentile" and "fpr".
    */
   final val selectorType = new Param[String](this, "selectorType",
-    "ChiSqSelector Type: KBest, Percentile, FPR")
-  setDefault(selectorType -> ChiSqSelectorType.KBest.toString)
+    "The selector type of the ChisqSelector. " +
+      "Supported options: kbest (default), percentile and fpr.",
+    ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray))
+  setDefault(selectorType -> OldChiSqSelector.KBest)
 
   /** @group getParam */
-  def getChiSqSelectorType: String = $(selectorType)
+  def getSelectorType: String = $(selectorType)
 }
 
 /**
  * Chi-Squared feature selection, which selects categorical features to use for predicting a
  * categorical label.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
  */
 @Since("1.6.0")
 final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String)
@@ -104,24 +113,21 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
   @Since("1.6.0")
   def this() = this(Identifiable.randomUID("chiSqSelector"))
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setSelectorType(value: String): this.type = set(selectorType, value)
+
   /** @group setParam */
   @Since("1.6.0")
-  def setNumTopFeatures(value: Int): this.type = {
-    set(selectorType, ChiSqSelectorType.KBest.toString)
-    set(numTopFeatures, value)
-  }
+  def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value)
 
+  /** @group setParam */
   @Since("2.1.0")
-  def setPercentile(value: Double): this.type = {
-    set(selectorType, ChiSqSelectorType.Percentile.toString)
-    set(percentile, value)
-  }
+  def setPercentile(value: Double): this.type = set(percentile, value)
 
+  /** @group setParam */
   @Since("2.1.0")
-  def setAlpha(value: Double): this.type = {
-    set(selectorType, ChiSqSelectorType.FPR.toString)
-    set(alpha, value)
-  }
+  def setAlpha(value: Double): this.type = set(alpha, value)
 
   /** @group setParam */
   @Since("1.6.0")
@@ -143,23 +149,23 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str
         case Row(label: Double, features: Vector) =>
           OldLabeledPoint(label, OldVectors.fromML(features))
       }
-    var selector = new feature.ChiSqSelector()
-    ChiSqSelectorType.withName($(selectorType)) match {
-      case ChiSqSelectorType.KBest =>
-        selector.setNumTopFeatures($(numTopFeatures))
-      case ChiSqSelectorType.Percentile =>
-        selector.setPercentile($(percentile))
-      case ChiSqSelectorType.FPR =>
-        selector.setAlpha($(alpha))
-      case errorType =>
-        throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
-    }
+    val selector = new feature.ChiSqSelector()
+      .setSelectorType($(selectorType))
+      .setNumTopFeatures($(numTopFeatures))
+      .setPercentile($(percentile))
+      .setAlpha($(alpha))
     val model = selector.fit(input)
     copyValues(new ChiSqSelectorModel(uid, model).setParent(this))
   }
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
+    val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType))
+    otherPairs.foreach { case (_, paramName: String) =>
+      if (isSet(getParam(paramName))) {
+        logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
+      }
+    }
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkNumericType(schema, $(labelCol))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 5cffbf0892888ecaa00851a141235bc343b3414c..904000f50d0a2cba3e364f54d2c1aefa9471cfde 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -629,35 +629,23 @@ private[python] class PythonMLLibAPI extends Serializable {
   }
 
   /**
-   * Java stub for ChiSqSelector.fit() when the seletion type is KBest. This stub returns a
+   * Java stub for ChiSqSelector.fit(). This stub returns a
    * handle to the Java object instead of the content of the Java object.
    * Extra care needs to be taken in the Python code to ensure it gets freed on
    * exit; see the Py4J documentation.
    */
-  def fitChiSqSelectorKBest(numTopFeatures: Int,
-    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setNumTopFeatures(numTopFeatures).fit(data.rdd)
-  }
-
-  /**
-   * Java stub for ChiSqSelector.fit() when the selection type is Percentile. This stub returns a
-   * handle to the Java object instead of the content of the Java object.
-   * Extra care needs to be taken in the Python code to ensure it gets freed on
-   * exit; see the Py4J documentation.
-   */
-  def fitChiSqSelectorPercentile(percentile: Double,
-    data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setPercentile(percentile).fit(data.rdd)
-  }
-
-  /**
-   * Java stub for ChiSqSelector.fit() when the selection type is FPR. This stub returns a
-   * handle to the Java object instead of the content of the Java object.
-   * Extra care needs to be taken in the Python code to ensure it gets freed on
-   * exit; see the Py4J documentation.
-   */
-  def fitChiSqSelectorFPR(alpha: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
-    new ChiSqSelector().setAlpha(alpha).fit(data.rdd)
+  def fitChiSqSelector(
+      selectorType: String,
+      numTopFeatures: Int,
+      percentile: Double,
+      alpha: Double,
+      data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = {
+    new ChiSqSelector()
+      .setSelectorType(selectorType)
+      .setNumTopFeatures(numTopFeatures)
+      .setPercentile(percentile)
+      .setAlpha(alpha)
+      .fit(data.rdd)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index f68a017184b21b87f0a157ddc4250e28dd4861a3..0f7c6e8bc04bb88891fe37f2b93cf69cfad1c56b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -32,12 +32,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Row, SparkSession}
 
-@Since("2.1.0")
-private[spark] object ChiSqSelectorType extends Enumeration {
-  type SelectorType = Value
-  val KBest, Percentile, FPR = Value
-}
-
 /**
  * Chi Squared selector model.
  *
@@ -166,19 +160,18 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
 
 /**
  * Creates a ChiSquared feature selector.
- * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
- * `KBest` chooses the `k` top features according to a chi-squared test.
- * `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
- * `FPR` chooses all features whose false positive rate meets some threshold.
- * By default, the selection method is `KBest`, the default number of top features is 50.
- * User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+ * The selector supports three selection methods: `kbest`, `percentile` and `fpr`.
+ * `kbest` chooses the `k` top features according to a chi-squared test.
+ * `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+ * `fpr` chooses all features whose false positive rate meets some threshold.
+ * By default, the selection method is `kbest`, the default number of top features is 50.
  */
 @Since("1.3.0")
 class ChiSqSelector @Since("2.1.0") () extends Serializable {
   var numTopFeatures: Int = 50
   var percentile: Double = 0.1
   var alpha: Double = 0.05
-  var selectorType = ChiSqSelectorType.KBest
+  var selectorType = ChiSqSelector.KBest
 
   /**
    * The is the same to call this() and setNumTopFeatures(numTopFeatures)
@@ -192,7 +185,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   @Since("1.6.0")
   def setNumTopFeatures(value: Int): this.type = {
     numTopFeatures = value
-    selectorType = ChiSqSelectorType.KBest
     this
   }
 
@@ -200,7 +192,6 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   def setPercentile(value: Double): this.type = {
     require(0.0 <= value && value <= 1.0, "Percentile must be in [0,1]")
     percentile = value
-    selectorType = ChiSqSelectorType.Percentile
     this
   }
 
@@ -208,12 +199,13 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   def setAlpha(value: Double): this.type = {
     require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]")
     alpha = value
-    selectorType = ChiSqSelectorType.FPR
     this
   }
 
   @Since("2.1.0")
-  def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = {
+  def setSelectorType(value: String): this.type = {
+    require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value),
+      s"ChiSqSelector Type: $value was not supported.")
     selectorType = value
     this
   }
@@ -230,11 +222,11 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
     val chiSqTestResult = Statistics.chiSqTest(data)
       .zipWithIndex.sortBy { case (res, _) => -res.statistic }
     val features = selectorType match {
-      case ChiSqSelectorType.KBest => chiSqTestResult
+      case ChiSqSelector.KBest => chiSqTestResult
         .take(numTopFeatures)
-      case ChiSqSelectorType.Percentile => chiSqTestResult
+      case ChiSqSelector.Percentile => chiSqTestResult
         .take((chiSqTestResult.length * percentile).toInt)
-      case ChiSqSelectorType.FPR => chiSqTestResult
+      case ChiSqSelector.FPR => chiSqTestResult
         .filter{ case (res, _) => res.pValue < alpha }
       case errorType =>
         throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType")
@@ -244,3 +236,22 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable {
   }
 }
 
+@Since("2.1.0")
+object ChiSqSelector {
+
+  /** String name for `kbest` selector type. */
+  private[spark] val KBest: String = "kbest"
+
+  /** String name for `percentile` selector type. */
+  private[spark] val Percentile: String = "percentile"
+
+  /** String name for `fpr` selector type. */
+  private[spark] val FPR: String = "fpr"
+
+  /** Set of selector type and param pairs that ChiSqSelector supports. */
+  private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures",
+    Percentile -> "percentile", FPR -> "alpha")
+
+  /** Set of selector types that ChiSqSelector supports. */
+  private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index e0293dbc4b0b23f40468c7241b41bc3db1d662f9..6b56e4200250c7c3b1820002ff49c1d11ac3dd3e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -50,6 +50,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       .toDF("label", "data", "preFilteredData")
 
     val selector = new ChiSqSelector()
+      .setSelectorType("kbest")
       .setNumTopFeatures(1)
       .setFeaturesCol("data")
       .setLabelCol("label")
@@ -60,12 +61,28 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
         assert(vec1 ~== vec2 absTol 1e-1)
     }
 
-    selector.setPercentile(0.34).fit(df).transform(df)
-    .select("filtered", "preFilteredData").collect().foreach {
-      case Row(vec1: Vector, vec2: Vector) =>
-        assert(vec1 ~== vec2 absTol 1e-1)
-    }
+    selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df)
+      .select("filtered", "preFilteredData").collect().foreach {
+        case Row(vec1: Vector, vec2: Vector) =>
+          assert(vec1 ~== vec2 absTol 1e-1)
+      }
+
+    val preFilteredData2 = Seq(
+      Vectors.dense(8.0, 7.0),
+      Vectors.dense(0.0, 9.0),
+      Vectors.dense(0.0, 9.0),
+      Vectors.dense(8.0, 9.0)
+    )
 
+    val df2 = sc.parallelize(data.zip(preFilteredData2))
+      .map(x => (x._1.label, x._1.features, x._2))
+      .toDF("label", "data", "preFilteredData")
+
+    selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2)
+      .select("filtered", "preFilteredData").collect().foreach {
+        case Row(vec1: Vector, vec2: Vector) =>
+          assert(vec1 ~== vec2 absTol 1e-1)
+      }
   }
 
   test("ChiSqSelector read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index e181a544f7159f2c0c5ebcacfd4d00a1c457af0c..ec23a4aa7364dc15da9790ce5da4274f0ace0f1a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -76,7 +76,7 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(9.0))))
-    val model = new ChiSqSelector().setAlpha(0.1).fit(labeledDiscreteData)
+    val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
     }.collect().toSet
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index c45434f1a57ca141553caab11eea3ee75996fdcb..12a13849dc9bc21c6dd546c18e775ab0b0de3050 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2586,39 +2586,68 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
     .. versionadded:: 2.0.0
     """
 
+    selectorType = Param(Params._dummy(), "selectorType",
+                         "The selector type of the ChisqSelector. " +
+                         "Supported options: kbest (default), percentile and fpr.",
+                         typeConverter=TypeConverters.toString)
+
     numTopFeatures = \
         Param(Params._dummy(), "numTopFeatures",
               "Number of features that selector will select, ordered by statistics value " +
               "descending. If the number of features is < numTopFeatures, then this will select " +
               "all features.", typeConverter=TypeConverters.toInt)
 
+    percentile = Param(Params._dummy(), "percentile", "Percentile of features that selector " +
+                       "will select, ordered by statistics value descending.",
+                       typeConverter=TypeConverters.toFloat)
+
+    alpha = Param(Params._dummy(), "alpha", "The highest p-value for features to be kept.",
+                  typeConverter=TypeConverters.toFloat)
+
     @keyword_only
-    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
+    def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None,
+                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05):
         """
-        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label")
+        __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+                 labelCol="label", selectorType="kbest", percentile=0.1, alpha=0.05)
         """
         super(ChiSqSelector, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
-        self._setDefault(numTopFeatures=50)
+        self._setDefault(numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
     def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,
-                  labelCol="labels"):
+                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05):
         """
-        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None,\
-                  labelCol="labels")
+        setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, \
+                  labelCol="labels", selectorType="kbest", percentile=0.1, alpha=0.05)
         Sets params for this ChiSqSelector.
         """
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
+    @since("2.1.0")
+    def setSelectorType(self, value):
+        """
+        Sets the value of :py:attr:`selectorType`.
+        """
+        return self._set(selectorType=value)
+
+    @since("2.1.0")
+    def getSelectorType(self):
+        """
+        Gets the value of selectorType or its default value.
+        """
+        return self.getOrDefault(self.selectorType)
+
     @since("2.0.0")
     def setNumTopFeatures(self, value):
         """
         Sets the value of :py:attr:`numTopFeatures`.
+        Only applicable when selectorType = "kbest".
         """
         return self._set(numTopFeatures=value)
 
@@ -2629,6 +2658,36 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
         """
         return self.getOrDefault(self.numTopFeatures)
 
+    @since("2.1.0")
+    def setPercentile(self, value):
+        """
+        Sets the value of :py:attr:`percentile`.
+        Only applicable when selectorType = "percentile".
+        """
+        return self._set(percentile=value)
+
+    @since("2.1.0")
+    def getPercentile(self):
+        """
+        Gets the value of percentile or its default value.
+        """
+        return self.getOrDefault(self.percentile)
+
+    @since("2.1.0")
+    def setAlpha(self, value):
+        """
+        Sets the value of :py:attr:`alpha`.
+        Only applicable when selectorType = "fpr".
+        """
+        return self._set(alpha=value)
+
+    @since("2.1.0")
+    def getAlpha(self):
+        """
+        Gets the value of alpha or its default value.
+        """
+        return self.getOrDefault(self.alpha)
+
     def _create_model(self, java_model):
         return ChiSqSelectorModel(java_model)
 
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 077c11370eb3f73e3b904e35f0a6fac6695bcdd8..4aea81840a16225d75339947e4ca56afbc2dbe0d 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -271,22 +271,14 @@ class ChiSqSelectorModel(JavaVectorTransformer):
         return JavaVectorTransformer.transform(self, vector)
 
 
-class ChiSqSelectorType:
-    """
-    This class defines the selector types of Chi Square Selector.
-    """
-    KBest, Percentile, FPR = range(3)
-
-
 class ChiSqSelector(object):
     """
     Creates a ChiSquared feature selector.
     The selector supports three selection methods: `KBest`, `Percentile` and `FPR`.
-    `KBest` chooses the `k` top features according to a chi-squared test.
-    `Percentile` is similar but chooses a fraction of all features instead of a fixed number.
-    `FPR` chooses all features whose false positive rate meets some threshold.
-    By default, the selection method is `KBest`, the default number of top features is 50.
-    User can use setNumTopFeatures, setPercentile and setAlpha to set different selection methods.
+    `kbest` chooses the `k` top features according to a chi-squared test.
+    `percentile` is similar but chooses a fraction of all features instead of a fixed number.
+    `fpr` chooses all features whose false positive rate meets some threshold.
+    By default, the selection method is `kbest`, the default number of top features is 50.
 
     >>> data = [
     ...     LabeledPoint(0.0, SparseVector(3, {0: 8.0, 1: 7.0})),
@@ -299,7 +291,8 @@ class ChiSqSelector(object):
     SparseVector(1, {0: 6.0})
     >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
     DenseVector([5.0])
-    >>> model = ChiSqSelector().setPercentile(0.34).fit(sc.parallelize(data))
+    >>> model = ChiSqSelector().setSelectorType("percentile").setPercentile(0.34).fit(
+    ...     sc.parallelize(data))
     >>> model.transform(SparseVector(3, {1: 9.0, 2: 6.0}))
     SparseVector(1, {0: 6.0})
     >>> model.transform(DenseVector([8.0, 9.0, 5.0]))
@@ -310,41 +303,52 @@ class ChiSqSelector(object):
     ...     LabeledPoint(1.0, [0.0, 9.0, 8.0, 4.0]),
     ...     LabeledPoint(2.0, [8.0, 9.0, 5.0, 9.0])
     ... ]
-    >>> model = ChiSqSelector().setAlpha(0.1).fit(sc.parallelize(data))
+    >>> model = ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(sc.parallelize(data))
     >>> model.transform(DenseVector([1.0,2.0,3.0,4.0]))
     DenseVector([4.0])
 
     .. versionadded:: 1.4.0
     """
-    def __init__(self, numTopFeatures=50):
+    def __init__(self, numTopFeatures=50, selectorType="kbest", percentile=0.1, alpha=0.05):
         self.numTopFeatures = numTopFeatures
-        self.selectorType = ChiSqSelectorType.KBest
+        self.selectorType = selectorType
+        self.percentile = percentile
+        self.alpha = alpha
 
     @since('2.1.0')
     def setNumTopFeatures(self, numTopFeatures):
         """
-        set numTopFeature for feature selection by number of top features
+        set numTopFeature for feature selection by number of top features.
+        Only applicable when selectorType = "kbest".
         """
         self.numTopFeatures = int(numTopFeatures)
-        self.selectorType = ChiSqSelectorType.KBest
         return self
 
     @since('2.1.0')
     def setPercentile(self, percentile):
         """
-        set percentile [0.0, 1.0] for feature selection by percentile
+        set percentile [0.0, 1.0] for feature selection by percentile.
+        Only applicable when selectorType = "percentile".
         """
         self.percentile = float(percentile)
-        self.selectorType = ChiSqSelectorType.Percentile
         return self
 
     @since('2.1.0')
     def setAlpha(self, alpha):
         """
-        set alpha [0.0, 1.0] for feature selection by FPR
+        set alpha [0.0, 1.0] for feature selection by FPR.
+        Only applicable when selectorType = "fpr".
         """
         self.alpha = float(alpha)
-        self.selectorType = ChiSqSelectorType.FPR
+        return self
+
+    @since('2.1.0')
+    def setSelectorType(self, selectorType):
+        """
+        set the selector type of the ChisqSelector.
+        Supported options: "kbest" (default), "percentile" and "fpr".
+        """
+        self.selectorType = str(selectorType)
         return self
 
     @since('1.4.0')
@@ -357,15 +361,8 @@ class ChiSqSelector(object):
                      treated as categorical for each distinct value.
                      Apply feature discretizer before using this function.
         """
-        if self.selectorType == ChiSqSelectorType.KBest:
-            jmodel = callMLlibFunc("fitChiSqSelectorKBest", self.numTopFeatures, data)
-        elif self.selectorType == ChiSqSelectorType.Percentile:
-            jmodel = callMLlibFunc("fitChiSqSelectorPercentile", self.percentile, data)
-        elif self.selectorType == ChiSqSelectorType.FPR:
-            jmodel = callMLlibFunc("fitChiSqSelectorFPR", self.alpha, data)
-        else:
-            raise ValueError("ChiSqSelector type supports KBest(0), Percentile(1) and"
-                             " FPR(2), the current value is: %s" % self.selectorType)
+        jmodel = callMLlibFunc("fitChiSqSelector", self.selectorType, self.numTopFeatures,
+                               self.percentile, self.alpha, data)
         return ChiSqSelectorModel(jmodel)