Commit 96c4846d authored by Joseph K. Bradley, committed by Xiangrui Meng

[SPARK-7573] [ML] OneVsRest cleanups

Minor cleanups discussed with [~mengxr]:
* move OneVsRest from reduction to classification sub-package
* make model constructor private

Some doc cleanups too

CC: harsha2010  Could you please verify this looks OK?  Thanks!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #6097 from jkbradley/onevsrest-cleanup and squashes the following commits:

4ecd48d [Joseph K. Bradley] org imports
430b065 [Joseph K. Bradley] moved OneVsRest from reduction subpackage to classification.  small java doc style fixes
9f8b9b9 [Joseph K. Bradley] Small cleanups to OneVsRest.  Made model constructor private to ml package.
parent f0c1bc34
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
-package org.apache.spark.ml.reduction
+package org.apache.spark.ml.classification
 import java.util.UUID
@@ -24,7 +24,6 @@ import scala.language.existentials
 import org.apache.spark.annotation.{AlphaComponent, Experimental}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
 import org.apache.spark.ml.param.Param
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
@@ -57,20 +56,21 @@ private[ml] trait OneVsRestParams extends PredictorParams {
 }
 /**
  * :: AlphaComponent ::
  *
  * Model produced by [[OneVsRest]].
- * Stores the models resulting from training k different classifiers:
- * one for each class.
- * Each example is scored against all k models and the model with highest score
+ * This stores the models resulting from training k binary classifiers: one for each class.
+ * Each example is scored against all k models, and the model with the highest score
  * is picked to label the example.
- * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
- * @param parent
+ *
  * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
  *                      representing the number of classes in training dataset otherwise.
- * @param models the binary classification models for reduction.
- *               The i-th model is produced by testing the i-th class vs the rest.
+ * @param models The binary classification models for the reduction.
+ *               The i-th model is produced by testing the i-th class (taking label 1) vs the rest
+ *               (taking label 0).
  */
 @AlphaComponent
-class OneVsRestModel(
+class OneVsRestModel private[ml] (
     override val parent: OneVsRest,
     labelMetadata: Metadata,
     val models: Array[_ <: ClassificationModel[_,_]])
@@ -90,7 +90,7 @@ class OneVsRestModel(
     // add an accumulator column to store predictions of all the models
     val accColName = "mbc$acc" + UUID.randomUUID().toString
     val init: () => Map[Int, Double] = () => {Map()}
-    val mapType = MapType(IntegerType, DoubleType, false)
+    val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
     val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
     // persist if underlying dataset is not persistent.
@@ -101,7 +101,7 @@
     // update the accumulator column with the result of prediction of models
     val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
-      case (df, (model, index)) => {
+      case (df, (model, index)) =>
         val rawPredictionCol = model.getRawPredictionCol
         val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
@@ -110,7 +110,7 @@
         val update: (Map[Int, Double], Vector) => Map[Int, Double] =
           (predictions: Map[Int, Double], prediction: Vector) => {
             predictions + ((index, prediction(1)))
-        }
+          }
         val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
         val transformedDataset = model.transform(df).select(columns:_*)
         val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
@@ -118,7 +118,6 @@ class OneVsRestModel(
         // switch out the intermediate column with the accumulator column
         updatedDataset.select(newColumns:_*).withColumnRenamed(tmpColName, accColName)
-      }
     }
     if (handlePersistence) {
@@ -149,8 +148,8 @@ class OneVsRestModel(
 final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
   /** @group setParam */
-  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
   def setClassifier(value: Classifier[_,_,_]): this.type = {
+    // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed
     set(classifier, value.asInstanceOf[ClassifierType])
   }
@@ -201,9 +200,8 @@ final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
     // extract label metadata from label column if present, or create a nominal attribute
     // to output the number of labels
     val labelAttribute = Attribute.fromStructField(labelSchema) match {
-      case _: NumericAttribute | UnresolvedAttribute => {
+      case _: NumericAttribute | UnresolvedAttribute =>
         NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
-      }
       case attr: Attribute => attr
     }
     copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
......
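The hunks below apply the same package move to the Java and Scala test suites. For orientation, here is a minimal sketch of how the relocated estimator is used after this change; it is illustrative only, and the `training` and `test` DataFrames (with "label" and "features" columns) are assumed rather than part of this commit.

// Illustrative usage of OneVsRest after the move to org.apache.spark.ml.classification.
// `training` and `test` are assumed DataFrames with "label" and "features" columns.
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// Base binary classifier, trained once per class.
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)

// OneVsRest fits k copies of the base classifier, one per class value.
val ovr = new OneVsRest().setClassifier(lr)
val ovrModel = ovr.fit(training)

// Each row is scored against all k models; the class of the highest-scoring
// model becomes the prediction, matching the scaladoc above.
val predictions = ovrModel.transform(test)

Per the updated @param models note, the i-th binary problem maps the i-th class to label 1 and everything else to label 0, which is why transform() accumulates each model's raw score at index 1.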
@@ -15,21 +15,20 @@
  * limitations under the License.
  */
-package org.apache.spark.ml.reduction;
+package org.apache.spark.ml.classification;
 import java.io.Serializable;
 import java.util.List;
-import static scala.collection.JavaConversions.seqAsJavaList;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
+import static scala.collection.JavaConversions.seqAsJavaList;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegression;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.DataFrame;
@@ -48,10 +47,8 @@ public class JavaOneVsRestSuite implements Serializable {
     jsql = new SQLContext(jsc);
     int nPoints = 3;
-    /**
-     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
-     * As a result, we are actually drawing samples from probability distribution of built model.
-     */
+    // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+    // As a result, we are drawing samples from probability distribution of an actual model.
     double[] weights = {
       -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
       -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
......
@@ -15,12 +15,11 @@
  * limitations under the License.
  */
-package org.apache.spark.ml.reduction
+package org.apache.spark.ml.classification
 import org.scalatest.FunSuite
 import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -42,10 +41,8 @@ class OneVsRestSuite extends FunSuite with MLlibTestSparkContext {
     sqlContext = new SQLContext(sc)
     val nPoints = 1000
-    /**
-     * The following weights and xMean/xVariance are computed from iris dataset with lambda = 0.2.
-     * As a result, we are actually drawing samples from probability distribution of built model.
-     */
+    // The following weights and xMean/xVariance are computed from iris dataset with lambda=0.2.
+    // As a result, we are drawing samples from probability distribution of an actual model.
     val weights = Array(
       -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
       -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
......
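Both suites feed these weights to MLlib's LogisticRegressionSuite.generateMultinomialLogisticInput (imported in each suite above) to sample labeled points from a model fitted on iris. A rough sketch of that setup follows, assuming sc, sqlContext, weights and nPoints come from the surrounding fixture; the xMean, xVariance and seed values below are placeholders, since the real ones fall outside the shown hunks.

// Sketch only: turn the weights into a training DataFrame for the suite.
// xMean/xVariance and the seed are placeholder values, not taken from this diff.
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput

val xMean = Array(5.8, 3.0, 3.8, 1.2)       // placeholder feature means
val xVariance = Array(0.7, 0.2, 3.1, 0.6)   // placeholder feature variances
val rdd = sc.parallelize(
  generateMultinomialLogisticInput(weights, xMean, xVariance, true, nPoints, 42), 2)
val dataset = sqlContext.createDataFrame(rdd)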