Skip to content
Snippets Groups Projects
Commit 1ffa8cb9 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7329] [MLLIB] simplify ParamGridBuilder impl

as suggested by justinuang on #5601.

Author: Xiangrui Meng <meng@databricks.com>

Closes #5873 from mengxr/SPARK-7329 and squashes the following commits:

d08f9cf [Xiangrui Meng] simplify tests
b7a7b9b [Xiangrui Meng] simplify grid build
parent 9e25b09f
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
# #
import itertools
__all__ = ['ParamGridBuilder'] __all__ = ['ParamGridBuilder']
...@@ -37,14 +39,10 @@ class ParamGridBuilder(object): ...@@ -37,14 +39,10 @@ class ParamGridBuilder(object):
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ {lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] {lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
>>> fail_count = 0 >>> len(output) == len(expected)
>>> for e in expected: True
... if e not in output: >>> all([m in expected for m in output])
... fail_count += 1 True
>>> if len(expected) != len(output):
... fail_count += 1
>>> fail_count
0
""" """
def __init__(self): def __init__(self):
...@@ -76,17 +74,9 @@ class ParamGridBuilder(object): ...@@ -76,17 +74,9 @@ class ParamGridBuilder(object):
Builds and returns all combinations of parameters specified Builds and returns all combinations of parameters specified
by the param grid. by the param grid.
""" """
param_maps = [{}] keys = self._param_grid.keys()
for (param, values) in self._param_grid.items(): grid_values = self._param_grid.values()
new_param_maps = [] return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
for value in values:
for old_map in param_maps:
copied_map = old_map.copy()
copied_map[param] = value
new_param_maps.append(copied_map)
param_maps = new_param_maps
return param_maps
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment