Skip to content
Snippets Groups Projects
Commit 89f21f66 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output

The temporary column should be dropped after we get the prediction column. harsha2010

Author: Xiangrui Meng <meng@databricks.com>

Closes #6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output
parent 605ddbb2
No related branches found
No related tags found
No related merge requests found
......@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}
}
......
......@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
val logReg = new LogisticRegression()
.setMaxIter(1)
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
}
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment