diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 0e53877de92dba22f62d76b9b3f1b288bdf50660..f6a5f27425d1f3aa6ab1fc829f6c5e12d13efee7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -113,7 +113,8 @@ abstract class Predictor[
    *
    * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
    */
-  protected def featuresDataType: DataType = new VectorUDT
+  @DeveloperApi
+  private[ml] def featuresDataType: DataType = new VectorUDT
 
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, featuresDataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index d7dee8fed2a55f11e6c551e27566255fcaeb0613..f5f37aa77929cd789e1f304c35730cfc668e7c67 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -123,6 +123,7 @@ class AttributeGroup private (
             nominalMetadata += nominal.toMetadataImpl(withType = false)
           case binary: BinaryAttribute =>
             binaryMetadata += binary.toMetadataImpl(withType = false)
+          case UnresolvedAttribute =>
         }
       val attrBldr = new MetadataBuilder
       if (numericMetadata.nonEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
index 65e7e43d5a5b0f000dc12843d0e97481544b0de0..a83febd7de2cc842ae75d7e9ad9b675ee91b2ac9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeType.scala
@@ -43,6 +43,12 @@ object AttributeType {
     Binary
   }
 
+  /** Unresolved type. */
+  val Unresolved: AttributeType = {
+    case object Unresolved extends AttributeType("unresolved")
+    Unresolved
+  }
+
   /**
    * Gets the [[AttributeType]] object from its name.
    * @param name attribute type name: "numeric", "nominal", or "binary"
@@ -54,6 +60,8 @@ object AttributeType {
       Nominal
     } else if (name == Binary.name) {
       Binary
+    } else if (name == Unresolved.name) {
+      Unresolved
     } else {
       throw new IllegalArgumentException(s"Cannot recognize type $name.")
     }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 5717d6ec2eaec17fe3897f7deed50116d33ff4d3..e8f7f152784a1fd6d065b1b077bdbb2609f21130 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -125,7 +125,13 @@ private[attribute] trait AttributeFactory {
    */
   def fromStructField(field: StructField): Attribute = {
     require(field.dataType == DoubleType)
-    fromMetadata(field.metadata.getMetadata(AttributeKeys.ML_ATTR)).withName(field.name)
+    val metadata = field.metadata
+    val mlAttr = AttributeKeys.ML_ATTR
+    if (metadata.contains(mlAttr)) {
+      fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name)
+    } else {
+      UnresolvedAttribute
+    }
   }
 }
 
@@ -535,3 +541,32 @@ object BinaryAttribute extends AttributeFactory {
     new BinaryAttribute(name, index, values)
   }
 }
+
+/**
+ * An unresolved attribute.
+ */
+object UnresolvedAttribute extends Attribute {
+
+  override def attrType: AttributeType = AttributeType.Unresolved
+
+  override def withIndex(index: Int): Attribute = this
+
+  override def isNumeric: Boolean = false
+
+  override def withoutIndex: Attribute = this
+
+  override def isNominal: Boolean = false
+
+  override def name: Option[String] = None
+
+  override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
+    Metadata.empty
+  }
+
+  override def withoutName: Attribute = this
+
+  override def index: Option[Int] = None
+
+  override def withName(name: String): Attribute = this
+
+}
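With this change, `Attribute.fromStructField` degrades gracefully: a `DoubleType` field without ML metadata now resolves to the `UnresolvedAttribute` singleton instead of failing inside `getMetadata`. A minimal sketch of the new behavior (not part of the patch; the field name is illustrative):

```scala
import org.apache.spark.ml.attribute.{Attribute, UnresolvedAttribute}
import org.apache.spark.sql.types.{DoubleType, Metadata, StructField}

// A Double column with no "ml_attr" metadata now resolves to UnresolvedAttribute
// rather than throwing on the missing metadata key.
val bare = StructField("label", DoubleType, nullable = false, Metadata.empty)
assert(Attribute.fromStructField(bare) == UnresolvedAttribute)
```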
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 07ea579d69893c9fdf13f7d640bc92d01b106b44..2e6313ac144850544c9bb02bada509de491f7f63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -375,6 +375,8 @@ class VectorIndexerModel private[ml] (
           }
         case (origAttr: Attribute, featAttr: NumericAttribute) =>
           origAttr.withIndex(featAttr.index.get)
+        case (origAttr: Attribute, _) =>
+          origAttr
       }
     } else {
       partialFeatureAttributes
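The new catch-all arm keeps attribute merging in `VectorIndexerModel` from throwing a `MatchError` when the computed feature attribute is not numeric (for example, unresolved): the original attribute is passed through unchanged. A toy illustration of just that fallthrough, outside any Spark job (not part of the patch):

```scala
import org.apache.spark.ml.attribute.{Attribute, NumericAttribute, UnresolvedAttribute}

val origAttr: Attribute = NumericAttribute.defaultAttr.withName("f0")
val featAttr: Attribute = UnresolvedAttribute

val merged = (origAttr, featAttr) match {
  case (orig: Attribute, feat: NumericAttribute) => orig.withIndex(feat.index.get)
  case (orig: Attribute, _) => orig // new fallthrough: keep the original attribute
}
assert(merged == origAttr)
```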
diff --git a/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0a6728ef1f779c9abfe655de5206891b8a6b21ec
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/reduction/OneVsRest.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction
+
+import java.util.UUID
+
+import scala.language.existentials
+
+import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.ml._
+import org.apache.spark.ml.attribute._
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
+import org.apache.spark.ml.param.Param
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for [[OneVsRest]].
+ */
+private[ml] trait OneVsRestParams extends PredictorParams {
+
+  type ClassifierType = Classifier[F, E, M] forSome {
+    type F
+    type M <: ClassificationModel[F, M]
+    type E <: Classifier[F, E, M]
+  }
+
+  /**
+   * param for the base binary classifier that the multiclass problem is reduced to.
+   * @group param
+   */
+  val classifier: Param[ClassifierType] =
+    new Param(this, "classifier", "base binary classifier")
+
+  /** @group getParam */
+  def getClassifier: ClassifierType = $(classifier)
+
+}
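`ClassifierType` is an existential type: it stands for any concrete `Classifier` without naming its features, estimator, and model type parameters. Scala infers the existentials at the call site, but Java cannot express them, which is why `setClassifier` further below takes a raw `Classifier[_, _, _]` and casts. A sketch of what the param accepts (not part of the patch; the classifier choice is illustrative):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.reduction.OneVsRest

// Any concrete binary classifier satisfies the existential ClassifierType,
// so it can be set without spelling out its three type parameters.
val ovr = new OneVsRest().setClassifier(new LogisticRegression())
```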
+/**
+ * Model produced by [[OneVsRest]].
+ * Stores the models resulting from training k different classifiers: one for each class.
+ * Each example is scored against all k models and the model with the highest score
+ * is picked to label the example.
+ * TODO: API may need to change when we introduce a ClassificationModel trait as the public API
+ * @param parent the [[OneVsRest]] estimator that produced this model
+ * @param labelMetadata Metadata of label column if it exists, or Nominal attribute
+ *                      representing the number of classes in training dataset otherwise.
+ * @param models the binary classification models for the reduction.
+ *               The i-th model is produced by testing the i-th class vs the rest.
+ */
+@AlphaComponent
+class OneVsRestModel(
+    override val parent: OneVsRest,
+    labelMetadata: Metadata,
+    val models: Array[_ <: ClassificationModel[_, _]])
+  extends Model[OneVsRestModel] with OneVsRestParams {
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    // Check schema
+    transformSchema(dataset.schema, logging = true)
+
+    // determine the input columns: these need to be passed through
+    val origCols = dataset.schema.map(f => col(f.name))
+
+    // add an accumulator column to store predictions of all the models
+    val accColName = "mbc$acc" + UUID.randomUUID().toString
+    val init: () => Map[Int, Double] = () => Map()
+    val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
+    val newDataset = dataset.withColumn(accColName, callUDF(init, mapType))
+
+    // persist if the underlying dataset is not persistent
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      newDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // update the accumulator column with the result of prediction of models
+    val aggregatedDataset = models.zipWithIndex.foldLeft[DataFrame](newDataset) {
+      case (df, (model, index)) => {
+        val rawPredictionCol = model.getRawPredictionCol
+        val columns = origCols ++ List(col(rawPredictionCol), col(accColName))
+
+        // add a temporary column to store intermediate scores and update
+        val tmpColName = "mbc$tmp" + UUID.randomUUID().toString
+        val update: (Map[Int, Double], Vector) => Map[Int, Double] =
+          (predictions: Map[Int, Double], prediction: Vector) => {
+            predictions + ((index, prediction(1)))
+          }
+        val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol))
+        val transformedDataset = model.transform(df).select(columns: _*)
+        val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf)
+        val newColumns = origCols ++ List(col(tmpColName))
+
+        // switch out the intermediate column with the accumulator column
+        updatedDataset.select(newColumns: _*).withColumnRenamed(tmpColName, accColName)
+      }
+    }
+
+    if (handlePersistence) {
+      newDataset.unpersist()
+    }
+
+    // output the index of the classifier with the highest confidence as the prediction
+    val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => {
+      predictions.maxBy(_._2)._1.toDouble
+    }
+
+    // output label and label metadata as prediction
+    val labelUdf = callUDF(label, DoubleType, col(accColName))
+    aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+  }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Reduction of Multiclass Classification to Binary Classification.
+ * Performs reduction using the one-against-all strategy.
+ * For a multiclass classification with k classes, train k models (one per class).
+ * Each example is scored against all k models and the model with the highest score
+ * is picked to label the example.
+ */
+@Experimental
+final class OneVsRest extends Estimator[OneVsRestModel] with OneVsRestParams {
+
+  /** @group setParam */
+  // TODO: Find a better way to do this. Existential Types don't work with Java API so cast needed.
+  def setClassifier(value: Classifier[_, _, _]): this.type = {
+    set(classifier, value.asInstanceOf[ClassifierType])
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
+  }
+
+  override def fit(dataset: DataFrame): OneVsRestModel = {
+    // determine the number of classes, either from metadata if provided or via computation
+    val labelSchema = dataset.schema($(labelCol))
+    val computeNumClasses: () => Int = () => {
+      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+      // classes are assumed to be numbered from 0, ..., maxLabelIndex
+      maxLabelIndex.toInt + 1
+    }
+    val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity)
+
+    val multiclassLabeled = dataset.select($(labelCol), $(featuresCol))
+
+    // persist if the underlying dataset is not persistent
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // create k relabeled columns, one per class, and train a binary classifier on each
+    val models = Range(0, numClasses).par.map { index =>
+
+      val label: Double => Double = (label: Double) => {
+        if (label.toInt == index) 1.0 else 0.0
+      }
+
+      // generate new label metadata for the binary problem
+      // TODO: use when ... otherwise after SPARK-7321 is merged
+      val labelUDF = callUDF(label, DoubleType, col($(labelCol)))
+      val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
+      val labelColName = "mc2b$" + index
+      val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta)
+      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+      val classifier = getClassifier
+      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+    }.toArray[ClassificationModel[_, _]]
+
+    if (handlePersistence) {
+      multiclassLabeled.unpersist()
+    }
+
+    // extract label metadata from the label column if present, or create a nominal attribute
+    // to record the number of labels
+    val labelAttribute = Attribute.fromStructField(labelSchema) match {
+      case _: NumericAttribute | UnresolvedAttribute =>
+        NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+      case attr: Attribute => attr
+    }
+    copyValues(new OneVsRestModel(this, labelAttribute.toMetadata(), models))
+  }
+}
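A minimal end-to-end sketch of the new estimator, not part of the patch, assuming DataFrames `train` and `test` with the default "label" and "features" columns (names and parameter values are illustrative):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.reduction.{OneVsRest, OneVsRestModel}

val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(100))
val ovrModel: OneVsRestModel = ovr.fit(train)  // trains one binary model per class
val predicted = ovrModel.transform(test)       // adds a "prediction" column with label metadata
predicted.select("features", "prediction").show()
```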
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
index c84c8b4eb744fbe752df07e8a640d5c858a5646d..56075c9a6b39f92261cb081173435517f1f9729b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.util
 
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
-  NumericAttribute}
+import org.apache.spark.ml.attribute._
 import org.apache.spark.sql.types.StructField
 
@@ -39,9 +38,9 @@ object MetadataUtils {
    */
   def getNumClasses(labelSchema: StructField): Option[Int] = {
     Attribute.fromStructField(labelSchema) match {
-      case numAttr: NumericAttribute => None
       case binAttr: BinaryAttribute => Some(2)
       case nomAttr: NominalAttribute => nomAttr.getNumValues
+      case _: NumericAttribute | UnresolvedAttribute => None
     }
   }
 
@@ -65,7 +64,7 @@ object MetadataUtils {
       Iterator()
     } else {
       attr match {
-        case numAttr: NumericAttribute => Iterator()
+        case _: NumericAttribute | UnresolvedAttribute => Iterator()
         case binAttr: BinaryAttribute => Iterator(idx -> 2)
         case nomAttr: NominalAttribute =>
           nomAttr.getNumValues match {
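Folding unresolved attributes into the numeric case means `getNumClasses` now answers `None` for any label column that carries no class-count information, rather than hitting a missing match arm. A quick sketch of both outcomes (not part of the patch; field and attribute names are illustrative):

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.sql.types.{DoubleType, Metadata, StructField}

// No ML metadata: fromStructField yields UnresolvedAttribute, so None comes back.
val plain = StructField("label", DoubleType, nullable = false, Metadata.empty)
assert(MetadataUtils.getNumClasses(plain).isEmpty)

// Nominal metadata with three values: Some(3).
val nominal = NominalAttribute.defaultAttr.withName("label").withNumValues(3)
assert(MetadataUtils.getNumClasses(nominal.toStructField()) == Some(3))
```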
diff --git a/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
new file mode 100644
index 0000000000000000000000000000000000000000..40a90ae9ded60e68084eff4f4251674042483d33
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/reduction/JavaOneVsRestSuite.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.reduction;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import static scala.collection.JavaConversions.seqAsJavaList;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.LogisticRegression;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaOneVsRestSuite implements Serializable {
+
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+  private transient DataFrame dataset;
+  private transient JavaRDD<LabeledPoint> datasetRDD;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaOneVsRestSuite");
+    jsql = new SQLContext(jsc);
+    int nPoints = 3;
+
+    /*
+     * The following weights and xMean/xVariance are computed from the iris dataset with
+     * lambda = 0.2. As a result, we are actually drawing samples from the probability
+     * distribution of the built model.
+     */
+    double[] weights = {
+      -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+      -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 };
+
+    double[] xMean = {5.843, 3.057, 3.758, 1.199};
+    double[] xVariance = {0.6856, 0.1899, 3.116, 0.581};
+    List<LabeledPoint> points = seqAsJavaList(generateMultinomialLogisticInput(
+      weights, xMean, xVariance, true, nPoints, 42));
+    datasetRDD = jsc.parallelize(points, 2);
+    dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void oneVsRestDefaultParams() {
+    OneVsRest ova = new OneVsRest();
+    ova.setClassifier(new LogisticRegression());
+    Assert.assertEquals("label", ova.getLabelCol());
+    Assert.assertEquals("prediction", ova.getPredictionCol());
+    OneVsRestModel ovaModel = ova.fit(dataset);
+    DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+    predictions.collectAsList();
+    Assert.assertEquals("label", ovaModel.getLabelCol());
+    Assert.assertEquals("prediction", ovaModel.getPredictionCol());
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 3e1a7196e37cb9d5f182accea5d172f21a1a6040..ec9b717e41ce8c41af435d4b7e022604cc25c197 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, Metadata}
+import org.apache.spark.sql.types._
 
 class AttributeSuite extends FunSuite {
 
@@ -209,4 +209,12 @@ class AttributeSuite extends FunSuite {
     intercept[IllegalArgumentException](attr.withName(""))
     intercept[IllegalArgumentException](attr.withIndex(-1))
   }
+
+  test("attribute from struct field") {
+    val metadata = NumericAttribute.defaultAttr.withName("label").toMetadata()
+    val fldWithoutMeta = new StructField("x", DoubleType, false, Metadata.empty)
+    assert(Attribute.fromStructField(fldWithoutMeta) == UnresolvedAttribute)
+    val fldWithMeta = new StructField("x", DoubleType, false, metadata)
+    assert(Attribute.fromStructField(fldWithMeta).isNumeric)
+  }
 }
+ */ + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + dataset = sqlContext.createDataFrame(rdd) + } + + test("one-vs-rest: default params") { + val numClasses = 3 + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression) + assert(ova.getLabelCol === "label") + assert(ova.getPredictionCol === "prediction") + val ovaModel = ova.fit(dataset) + assert(ovaModel.models.size === numClasses) + val transformedDataset = ovaModel.transform(dataset) + + // check for label metadata in prediction col + val predictionColSchema = transformedDataset.schema(ovaModel.getPredictionCol) + assert(MetadataUtils.getNumClasses(predictionColSchema) === Some(3)) + + val ovaResults = transformedDataset + .select("prediction", "label") + .map(row => (row.getDouble(0), row.getDouble(1))) + + val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(numClasses) + lr.optimizer.setRegParam(0.1).setNumIterations(100) + + val model = lr.run(rdd) + val results = model.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // determine the #confusion matrix in each class. + // bound how much error we allow compared to multinomial logistic regression. + val expectedMetrics = new MulticlassMetrics(results) + val ovaMetrics = new MulticlassMetrics(ovaResults) + assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400) + } + + test("one-vs-rest: pass label metadata correctly during train") { + val numClasses = 3 + val ova = new OneVsRest() + ova.setClassifier(new MockLogisticRegression) + + val labelMetadata = NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + val labelWithMetadata = dataset("label").as("label", labelMetadata.toMetadata()) + val features = dataset("features").as("features") + val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features) + ova.fit(datasetWithLabelMetadata) + } +} + +private class MockLogisticRegression extends LogisticRegression { + + setMaxIter(1) + + override protected def train(dataset: DataFrame): LogisticRegressionModel = { + val labelSchema = dataset.schema($(labelCol)) + // check for label attribute propagation. + assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) + super.train(dataset) + } +}