diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index ea757c5e40c76284185ef0025bc0277097e99faf..1741f19dc911c1140099881f4b0bd843fd3b5355 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams { /** * param for the base binary classifier that we reduce multiclass classification into. + * The base classifier input and output columns are ignored in favor of + * the ones specified in [[OneVsRest]]. * @group param */ val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") @@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String) set(classifier, value.asInstanceOf[ClassifierType]) } + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } @@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String) val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier - classifier.fit(trainingDataset, classifier.labelCol -> labelColName) + val paramMap = new ParamMap() + paramMap.put(classifier.labelCol -> labelColName) + paramMap.put(classifier.featuresCol -> getFeaturesCol) + paramMap.put(classifier.predictionCol -> getPredictionCol) + classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] if (handlePersistence) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 75cf5bd4ead4fc49adaf679daa4159bf05a6fae4..3775292f6dca76ca7e453d3080ddff8155dddebb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { ova.fit(datasetWithLabelMetadata) } + test("SPARK-8092: ensure label features and prediction cols are configurable") { + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexed") + + val indexedDataset = labelIndexer + .fit(dataset) + .transform(dataset) + .drop("label") + .withColumnRenamed("features", "f") + + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression()) + .setLabelCol(labelIndexer.getOutputCol) + .setFeaturesCol("f") + .setPredictionCol("p") + + val ovaModel = ova.fit(indexedDataset) + val transformedDataset = ovaModel.transform(indexedDataset) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("p")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1)