Skip to content
Snippets Groups Projects
Commit a140dd77 authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-10027] [ML] [PySpark] Add Python API missing methods for ml.feature

Missing methods of ml.feature are listed here:
```StringIndexer``` lacks the parameter ```handleInvalid```.
```StringIndexerModel``` lacks the method ```labels```.
```VectorIndexerModel``` lacks the methods ```numFeatures``` and ```categoryMaps```.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8313 from yanboliang/spark-10027.
parent 339a5271
No related branches found
No related tags found
No related merge requests found
...@@ -920,7 +920,7 @@ class StandardScalerModel(JavaModel): ...@@ -920,7 +920,7 @@ class StandardScalerModel(JavaModel):
@inherit_doc @inherit_doc
class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
""" """
.. note:: Experimental .. note:: Experimental
...@@ -943,19 +943,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): ...@@ -943,19 +943,20 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
""" """
@keyword_only @keyword_only
def __init__(self, inputCol=None, outputCol=None): def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
""" """
__init__(self, inputCol=None, outputCol=None) __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
""" """
super(StringIndexer, self).__init__() super(StringIndexer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self.__init__._input_kwargs kwargs = self.__init__._input_kwargs
self.setParams(**kwargs) self.setParams(**kwargs)
@keyword_only @keyword_only
def setParams(self, inputCol=None, outputCol=None): def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
""" """
setParams(self, inputCol=None, outputCol=None) setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this StringIndexer. Sets params for this StringIndexer.
""" """
kwargs = self.setParams._input_kwargs kwargs = self.setParams._input_kwargs
...@@ -1235,6 +1236,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol): ...@@ -1235,6 +1236,10 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> model = indexer.fit(df) >>> model = indexer.fit(df)
>>> model.transform(df).head().indexed >>> model.transform(df).head().indexed
DenseVector([1.0, 0.0]) DenseVector([1.0, 0.0])
>>> model.numFeatures
2
>>> model.categoryMaps
{0: {0.0: 0, -1.0: 1}}
>>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test >>> indexer.setParams(outputCol="test").fit(df).transform(df).collect()[1].test
DenseVector([0.0, 1.0]) DenseVector([0.0, 1.0])
>>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"} >>> params = {indexer.maxCategories: 3, indexer.outputCol: "vector"}
...@@ -1297,6 +1302,22 @@ class VectorIndexerModel(JavaModel): ...@@ -1297,6 +1302,22 @@ class VectorIndexerModel(JavaModel):
Model fitted by VectorIndexer. Model fitted by VectorIndexer.
""" """
@property
def numFeatures(self):
    """
    Number of features, i.e., the length of the Vectors that this model
    transforms (mirrors the ``numFeatures`` member of the underlying
    Scala ``VectorIndexerModel``).
    """
    # Delegate to the wrapped JVM model; the value is computed on the Scala side.
    feature_count = self._call_java("numFeatures")
    return feature_count
@property
def categoryMaps(self):
    """
    Feature value index for the categorical features.

    Keys are categorical feature indices (column indices); values are maps
    from original feature values to 0-based category indices. A feature
    absent from this map is treated as continuous.
    """
    # "javaCategoryMaps" is a Python-friendly view exposed by the Scala model;
    # the plain "categoryMaps" member returns Scala collections that do not
    # convert cleanly through Py4J.
    index_maps = self._call_java("javaCategoryMaps")
    return index_maps
@inherit_doc @inherit_doc
class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol): class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
......
...@@ -121,7 +121,10 @@ if __name__ == "__main__": ...@@ -121,7 +121,10 @@ if __name__ == "__main__":
("checkpointInterval", "checkpoint interval (>= 1)", None), ("checkpointInterval", "checkpoint interval (>= 1)", None),
("seed", "random seed", "hash(type(self).__name__)"), ("seed", "random seed", "hash(type(self).__name__)"),
("tol", "the convergence tolerance for iterative algorithms", None), ("tol", "the convergence tolerance for iterative algorithms", None),
("stepSize", "Step size to be used for each iteration of optimization.", None)] ("stepSize", "Step size to be used for each iteration of optimization.", None),
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
"out rows with bad values), or error (which will throw an errror). More options may be " +
"added later.", None)]
code = [] code = []
for name, doc, defaultValueStr in shared: for name, doc, defaultValueStr in shared:
param_code = _gen_param_header(name, doc, defaultValueStr) param_code = _gen_param_header(name, doc, defaultValueStr)
......
...@@ -432,6 +432,33 @@ class HasStepSize(Params): ...@@ -432,6 +432,33 @@ class HasStepSize(Params):
return self.getOrDefault(self.stepSize) return self.getOrDefault(self.stepSize)
class HasHandleInvalid(Params):
    """
    Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will
    filter out rows with bad values), or error (which will throw an error). More options may be
    added later.
    """

    # a placeholder to make it appear in the generated doc
    # (fixed typo "errror" -> "error" in the param description)
    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")

    def __init__(self):
        super(HasHandleInvalid, self).__init__()
        #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
        self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
        self._paramMap[self.handleInvalid] = value
        return self

    def getHandleInvalid(self):
        """
        Gets the value of handleInvalid or its default value.
        """
        return self.getOrDefault(self.handleInvalid)
class DecisionTreeParams(Params): class DecisionTreeParams(Params):
""" """
Mixin for Decision Tree parameters. Mixin for Decision Tree parameters.
...@@ -444,7 +471,7 @@ class DecisionTreeParams(Params): ...@@ -444,7 +471,7 @@ class DecisionTreeParams(Params):
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.") minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
def __init__(self): def __init__(self):
super(DecisionTreeParams, self).__init__() super(DecisionTreeParams, self).__init__()
...@@ -460,7 +487,7 @@ class DecisionTreeParams(Params): ...@@ -460,7 +487,7 @@ class DecisionTreeParams(Params):
self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.") self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
#: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.") self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
def setMaxDepth(self, value): def setMaxDepth(self, value):
""" """
Sets the value of :py:attr:`maxDepth`. Sets the value of :py:attr:`maxDepth`.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment