Skip to content
Snippets Groups Projects
Commit d4d762f2 authored by Ram Sriharsha, committed by Joseph K. Bradley
Browse files

[SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.

The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest.

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits:

6591dc6 [Ram Sriharsha] add documentation for params
b7024b1 [Ram Sriharsha] cleanup
f0e2bfb [Ram Sriharsha] merge with master
108d3d7 [Ram Sriharsha] merge with master
4f74126 [Ram Sriharsha] Allow label/features columns to be configurable
parent d249636e
No related branches found
No related tags found
No related merge requests found
......@@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
/**
 * Param holding the base binary classifier into which the multiclass
 * classification problem is reduced.
 * Input and output columns set on the base classifier itself are ignored;
 * the ones specified in [[OneVsRest]] are used instead.
 * @group param
 */
val classifier: Param[ClassifierType] =
  new Param[ClassifierType](this, "classifier", "base binary classifier")
......@@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
set(classifier, value.asInstanceOf[ClassifierType])
}
/**
 * Sets the name of the label column.
 *
 * @param value column name to store in the `labelCol` param
 * @return this instance, so that setter calls can be chained
 * @group setParam
 */
def setLabelCol(value: String): this.type = {
  set(labelCol, value)
}
/**
 * Sets the name of the features column.
 *
 * @param value column name to store in the `featuresCol` param
 * @return this instance, so that setter calls can be chained
 * @group setParam
 */
def setFeaturesCol(value: String): this.type = {
  set(featuresCol, value)
}
/**
 * Sets the name of the prediction column.
 *
 * @param value column name to store in the `predictionCol` param
 * @return this instance, so that setter calls can be chained
 * @group setParam
 */
def setPredictionCol(value: String): this.type = {
  set(predictionCol, value)
}
/**
 * Validates the given schema and derives the output schema, in fitting mode,
 * using the features data type reported by the configured base classifier.
 *
 * @param schema schema of the input dataset
 * @return the schema produced by `validateAndTransformSchema`
 */
override def transformSchema(schema: StructType): StructType = {
  // The expected features type comes from the base classifier, not from this estimator.
  val baseFeaturesType = getClassifier.featuresDataType
  validateAndTransformSchema(schema, fitting = true, baseFeaturesType)
}
......@@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
if (handlePersistence) {
......
......@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
......@@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
ova.fit(datasetWithLabelMetadata)
}
test("SPARK-8092: ensure label features and prediction cols are configurable") {
  // Index "label" into "indexed", then remove/rename the default columns so the
  // input only carries the custom names "indexed" and "f".
  val indexer = new StringIndexer().setInputCol("label").setOutputCol("indexed")
  val renamedInput = indexer.fit(dataset).transform(dataset)
    .drop("label")
    .withColumnRenamed("features", "f")

  // Point OneVsRest at the custom label/features columns and ask for output in "p".
  val oneVsRest = new OneVsRest()
    .setClassifier(new LogisticRegression())
    .setLabelCol(indexer.getOutputCol)
    .setFeaturesCol("f")
    .setPredictionCol("p")

  val predictions = oneVsRest.fit(renamedInput).transform(renamedInput)

  // The configured prediction column must appear in the transformed schema.
  assert(predictions.schema.fieldNames.toSet.contains("p"))
}
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
val logReg = new LogisticRegression()
.setMaxIter(1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment