From fe16fd0b8b717f01151bc659ec3299dab091c97a Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Mon, 31 Aug 2015 16:06:38 -0700
Subject: [PATCH] [SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF
 to generate new label at binary reduction

Currently OneVsRest use UDF to generate new binary label during training.
Considering that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use ```when ... otherwise``` which will be more efficiency.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8519 from yanboliang/spark-10349.
---
 .../org/apache/spark/ml/classification/OneVsRest.scala | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c62e132f5d..debc164bf2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] (
     // add an accumulator column to store predictions of all the models
     val accColName = "mbc$acc" + UUID.randomUUID().toString
     val initUDF = udf { () => Map[Int, Double]() }
-    val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
     val newDataset = dataset.withColumn(accColName, initUDF())
 
     // persist if underlying dataset is not persistent.
@@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String)
 
     // create k columns, one for each binary classifier.
     val models = Range(0, numClasses).par.map { index =>
-      val labelUDF = udf { (label: Double) =>
-        if (label.toInt == index) 1.0 else 0.0
-      }
-
       // generate new label metadata for the binary problem.
-      // TODO: use when ... otherwise after SPARK-7321 is merged
       val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
       val labelColName = "mc2b$" + index
-      val trainingDataset =
-        multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
+      val trainingDataset = multiclassLabeled.withColumn(
+        labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
       val classifier = getClassifier
       val paramMap = new ParamMap()
       paramMap.put(classifier.labelCol -> labelColName)
-- 
GitLab