diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 5148ad02d93aabdf0b0fae543823677cc306e4cd..557a53cc2314abd79f307e2defd6fb6f910b46a5 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -114,9 +114,15 @@ Continuing the earlier example: {% include_example java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java %} </div> -<!--- TODO: Add python model summaries once implemented --> <div data-lang="python" markdown="1"> -Logistic regression model summary is not yet supported in Python. +[`LogisticRegressionTrainingSummary`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionSummary) +provides a summary for a +[`LogisticRegressionModel`](api/python/pyspark.ml.html#pyspark.ml.classification.LogisticRegressionModel). +Currently, only binary classification is supported. Support for multiclass model summaries will be added in the future. + +Continuing the earlier example: + +{% include_example python/ml/logistic_regression_summary_example.py %} </div> </div> diff --git a/examples/src/main/python/ml/logistic_regression_summary_example.py b/examples/src/main/python/ml/logistic_regression_summary_example.py new file mode 100644 index 0000000000000000000000000000000000000000..bd440a1fbe8df752350e87141820a6957c0ecdd4 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression_summary_example.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +# $example on$ +from pyspark.ml.classification import LogisticRegression +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating Logistic Regression Summary. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression_summary_example.py +""" + +if __name__ == "__main__": + spark = SparkSession \ + .builder \ + .appName("LogisticRegressionSummary") \ + .getOrCreate() + + # Load training data + training = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + lr = LogisticRegression(maxIter=10, regParam=0.3, elasticNetParam=0.8) + + # Fit the model + lrModel = lr.fit(training) + + # $example on$ + # Extract the summary from the returned LogisticRegressionModel instance trained + # in the earlier example + trainingSummary = lrModel.summary + + # Obtain the objective per iteration + objectiveHistory = trainingSummary.objectiveHistory + print("objectiveHistory:") + for objective in objectiveHistory: + print(objective) + + # Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. + trainingSummary.roc.show() + print("areaUnderROC: " + str(trainingSummary.areaUnderROC)) + + # Set the model threshold to maximize F-Measure + fMeasure = trainingSummary.fMeasureByThreshold + maxFMeasure = fMeasure.groupBy().max('F-Measure').select('max(F-Measure)').head() + bestThreshold = fMeasure.where(fMeasure['F-Measure'] == maxFMeasure['max(F-Measure)']) \ + .select('threshold').head()['threshold'] + lr.setThreshold(bestThreshold) + # $example off$ + + spark.stop()