diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index a269dbf030e7c79e96dee7ac9c5aac937dac8eff..c49bc4ff124bd389f82058ade37faa2b6182d669 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -140,7 +140,7 @@ definitions of positive and negative labels is straightforward.
 #### Label based metrics
 
 Opposed to binary classification where there are only two possible labels, multiclass classification problems have many
-possible labels and so the concept of label-based metrics is introduced. Overall precision measures precision across all
+possible labels and so the concept of label-based metrics is introduced. Accuracy measures precision across all
 labels - the number of times any class was predicted correctly (true positives) normalized by the number of data points.
 Precision by label considers only one class, and measures the number of time a specific label was predicted correctly
 normalized by the number of times that label appears in the output.
@@ -182,20 +182,10 @@ $$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.
       </td>
     </tr>
     <tr>
-      <td>Overall Precision</td>
-      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
-        \mathbf{y}_i\right)$</td>
-    </tr>
-    <tr>
-      <td>Overall Recall</td>
-      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
+      <td>Accuracy</td>
+      <td>$ACC = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i -
         \mathbf{y}_i\right)$</td>
     </tr>
-    <tr>
-      <td>Overall F1-measure</td>
-      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}
-        {PPV + TPR}\right)$</td>
-    </tr>
     <tr>
       <td>Precision by label</td>
       <td>$PPV(\ell) = \frac{TP}{TP + FP} =
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 0b84e0a3fa7847bc94b4df1f3ed89ce11a84d033..794b1e7d9d8819bfd516aa493c26be1c1c6e6666 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -39,16 +39,16 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   def this() = this(Identifiable.randomUID("mcEval"))
 
   /**
-   * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
-   * `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
+   * param for metric name in evaluation (supports `"f1"` (default), `"weightedPrecision"`,
+   * `"weightedRecall"`, `"accuracy"`)
    * @group param
    */
   @Since("1.5.0")
   val metricName: Param[String] = {
-    val allowedParams = ParamValidators.inArray(Array("f1", "precision",
-      "recall", "weightedPrecision", "weightedRecall", "accuracy"))
+    val allowedParams = ParamValidators.inArray(Array("f1", "weightedPrecision",
+      "weightedRecall", "accuracy"))
     new Param(this, "metricName", "metric name in evaluation " +
-      "(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams)
+      "(f1|weightedPrecision|weightedRecall|accuracy)", allowedParams)
   }
 
   /** @group getParam */
@@ -82,8 +82,6 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
     val metrics = new MulticlassMetrics(predictionAndLabels)
     val metric = $(metricName) match {
       case "f1" => metrics.weightedFMeasure
-      case "precision" => metrics.accuracy
-      case "recall" => metrics.accuracy
       case "weightedPrecision" => metrics.weightedPrecision
       case "weightedRecall" => metrics.weightedRecall
       case "accuracy" => metrics.accuracy
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 522f6675d7f46d3bbd227e9fc5ff2228f9239c04..1a3a8a13a2d093ea58e140a848780158dcc66dd8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -33,7 +33,7 @@ class MulticlassClassificationEvaluatorSuite
     val evaluator = new MulticlassClassificationEvaluator()
       .setPredictionCol("myPrediction")
       .setLabelCol("myLabel")
-      .setMetricName("recall")
+      .setMetricName("accuracy")
     testDefaultReadWrite(evaluator)
   }
 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index b8b2b37af553dde81fe9106171733e23f139449b..c480525e9bd3a21c0d14312e23f9178a76f18d19 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -258,9 +258,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
     >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
     >>> evaluator.evaluate(dataset)
     0.66...
-    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
-    0.66...
-    >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
     0.66...
 
     .. versionadded:: 1.5.0