diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..a383bd0c0d26f37ca65f840cd95887bd9aa61018 --- /dev/null +++ b/python/pyspark/ml/tuning.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ['ParamGridBuilder'] + + +class ParamGridBuilder(object): + """ + Builder for a param grid used in grid search-based model selection. + + >>> from classification import LogisticRegression + >>> lr = LogisticRegression() + >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \ + .baseOn([lr.predictionCol, 'p']) \ + .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \ + .addGrid(lr.maxIter, [1, 5]) \ + .addGrid(lr.featuresCol, ['f']) \ + .build() + >>> expected = [ \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] + >>> fail_count = 0 + >>> for e in expected: + ... if e not in output: + ... fail_count += 1 + >>> if len(expected) != len(output): + ... fail_count += 1 + >>> fail_count + 0 + """ + + def __init__(self): + self._param_grid = {} + + def addGrid(self, param, values): + """ + Sets the given parameters in this grid to fixed values. + """ + self._param_grid[param] = values + + return self + + def baseOn(self, *args): + """ + Sets the given parameters in this grid to fixed values. + Accepts either a parameter dictionary or a list of (parameter, value) pairs. + """ + if isinstance(args[0], dict): + self.baseOn(*args[0].items()) + else: + for (param, value) in args: + self.addGrid(param, [value]) + + return self + + def build(self): + """ + Builds and returns all combinations of parameters specified + by the param grid. + """ + param_maps = [{}] + for (param, values) in self._param_grid.items(): + new_param_maps = [] + for value in values: + for old_map in param_maps: + copied_map = old_map.copy() + copied_map[param] = value + new_param_maps.append(copied_map) + param_maps = new_param_maps + + return param_maps + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/run-tests b/python/run-tests index 88b63b84fdc27247ebd5c1509766822c82d2bad2..0e0eee3564e7c0bcfb1b48aa40361326c24c97e4 100755 --- a/python/run-tests +++ b/python/run-tests @@ -98,6 +98,7 @@ function run_ml_tests() { echo "Run ml tests ..." run_test "pyspark/ml/feature.py" run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tuning.py" run_test "pyspark/ml/tests.py" }