From 97d4cd07406a3d2fd5be83b009988d8bc320b524 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Tue, 2 Jun 2015 16:51:17 -0700
Subject: [PATCH] [SPARK-8049] [MLLIB] drop tmp col from OneVsRest output

The temporary column should be dropped after we get the prediction column. harsha2010

Author: Xiangrui Meng <meng@databricks.com>

Closes #6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output

(cherry picked from commit 89f21f66b5549524d1a6e4fb576a4f80d9fef903)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
---
 .../org/apache/spark/ml/classification/OneVsRest.scala   | 1 +
 .../apache/spark/ml/classification/OneVsRestSuite.scala  | 9 +++++++++
 2 files changed, 10 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da388..825f9ed1b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
     // output label and label metadata as prediction
     val labelUdf = callUDF(label, DoubleType, col(accColName))
     aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+      .drop(accColName)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 770b56890f..1b354d077d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -94,6 +94,15 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
     val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
     ova.fit(datasetWithLabelMetadata)
   }
+
+  test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+    val logReg = new LogisticRegression()
+      .setMaxIter(1)
+    val ovr = new OneVsRest()
+      .setClassifier(logReg)
+    val output = ovr.fit(dataset).transform(dataset)
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
-- 
GitLab