diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java new file mode 100644 index 0000000000000000000000000000000000000000..22ba68d8c354caffe854b2e12e5d18e1b1da9710 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; +import org.apache.spark.SparkConf; + +/** + * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java + * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of + * this example {@link SimpleTextClassificationPipeline}. Run with + * <pre> + * bin/run-example ml.JavaSimpleTextClassificationPipeline + * </pre> + */ +public class JavaSimpleTextClassificationPipeline { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training documents, which are labeled. + List<LabeledDocument> localTraining = Lists.newArrayList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0)); + JavaSchemaRDD training = + jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. 
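+    // fit() runs the stages in order: tokenizer and hashingTF transform the training data,
+    // then the logistic regression estimator is trained on the resulting "features" column.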
+ PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + List<Document> localTest = Lists.newArrayList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop")); + JavaSchemaRDD test = + jsql.applySchema(jsc.parallelize(localTest), Document.class); + + // Make predictions on test documents. + model.transform(test).registerAsTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println(r); + } + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala new file mode 100644 index 0000000000000000000000000000000000000000..ee7897d9062d9e4b119afacb7b8e12f25c73c60d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.beans.BeanInfo + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.sql.SQLContext + +@BeanInfo +case class LabeledDocument(id: Long, text: String, label: Double) + +@BeanInfo +case class Document(id: Long, text: String) + +/** + * A simple text classification pipeline that recognizes "spark" from input text. This is to show + * how to create and configure an ML pipeline. Run with + * {{{ + * bin/run-example ml.SimpleTextClassificationPipeline + * }}} + */ +object SimpleTextClassificationPipeline { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training documents, which are labeled. + val training = sparkContext.parallelize(Seq( + LabeledDocument(0L, "a b c d e spark", 1.0), + LabeledDocument(1L, "b d", 0.0), + LabeledDocument(2L, "spark f g h", 1.0), + LabeledDocument(3L, "hadoop mapreduce", 0.0))) + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 
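+    // Each stage reads the previous stage's output column: tokenizer maps "text" to "words",
+    // hashingTF maps "words" to "features", and lr uses the default "features" and "label" columns.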
+ val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Prepare test documents, which are unlabeled. + val test = sparkContext.parallelize(Seq( + Document(4L, "spark i j k"), + Document(5L, "l m n"), + Document(6L, "mapreduce spark"), + Document(7L, "apache hadoop"))) + + // Make predictions on test documents. + model.transform(test) + .select('id, 'text, 'score, 'prediction) + .collect() + .foreach(println) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 87a7ddaba97f2bd6094a12c4ad8ad3309fbd0e35..dd68b27a78bdc1d0c7ed611a4475558bc879f6c1 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -100,6 +100,11 @@ <artifactId>junit-interface</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <scope>test</scope> + </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-streaming_${scala.binary.version}</artifactId> diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala new file mode 100644 index 0000000000000000000000000000000000000000..fdbee743e81774b2dbada6c1d831b00dbb1820d4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for estimators that fit models to data. + */ +@AlphaComponent +abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + val map = new ParamMap().put(paramPairs: _*) + fit(dataset, map) + } + + /** + * Fits a single model to the input data with provided parameter map. 
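+   * A hypothetical call, assuming `lr` is an already-configured
+   * [[org.apache.spark.ml.classification.LogisticRegression]] estimator:
+   * {{{
+   *   val model = lr.fit(dataset, ParamMap(lr.maxIter -> 50))
+   * }}}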
+ * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: SchemaRDD, paramMap: ParamMap): M + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * The default implementation uses a for loop on each parameter map. + * Subclasses could overwrite this to optimize multi-model training. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + paramMaps.map(fit(dataset, _)) + } + + // Java-friendly versions of fit. + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { + fit(dataset.schemaRDD, paramPairs: _*) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { + fit(dataset.schemaRDD, paramMap) + } + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { + fit(dataset.schemaRDD, paramMaps).asJava + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..db563dd550e56ed271d5237f20bfbaa24b745bc2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for evaluators that compute metrics from predictions. + */ +@AlphaComponent +abstract class Evaluator extends Identifiable { + + /** + * Evaluates the output. + * + * @param dataset a dataset that contains labels/observations and predictions. 
+ * @param paramMap parameter map that specifies the input columns and output metrics + * @return metric + */ + def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala new file mode 100644 index 0000000000000000000000000000000000000000..cd84b05bfb4969557065acc03982471d3072b8f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import java.util.UUID + +/** + * Object with a unique id. + */ +private[ml] trait Identifiable extends Serializable { + + /** + * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * random hex chars. + */ + private[ml] val uid: String = + this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala new file mode 100644 index 0000000000000000000000000000000000000000..cae5082b51196d2920fa2d8aef70a68484842b84 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap + +/** + * :: AlphaComponent :: + * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. + * + * @tparam M model type + */ +@AlphaComponent +abstract class Model[M <: Model[M]] extends Transformer { + /** + * The parent estimator that produced this model. + */ + val parent: Estimator[M] + + /** + * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. 
+ */ + val fittingParamMap: ParamMap +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala new file mode 100644 index 0000000000000000000000000000000000000000..e545df1e37b9c66c68ff8726a226c03c93b31a6f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, Param, ParamMap} +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * :: AlphaComponent :: + * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. + */ +@AlphaComponent +abstract class PipelineStage extends Serializable with Logging { + + /** + * Derives the output schema from the input schema and parameters. + */ + private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + + /** + * Derives the output schema from the input schema and parameters, optionally with logging. + */ + protected def transformSchema( + schema: StructType, + paramMap: ParamMap, + logging: Boolean): StructType = { + if (logging) { + logDebug(s"Input schema: ${schema.json}") + } + val outputSchema = transformSchema(schema, paramMap) + if (logging) { + logDebug(s"Expected output schema: ${outputSchema.json}") + } + outputSchema + } +} + +/** + * :: AlphaComponent :: + * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * be called on the input dataset to fit a model. Then the model, which is a transformer, will be + * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], + * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and + * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as + * an identity transformer. + */ +@AlphaComponent +class Pipeline extends Estimator[PipelineModel] { + + /** param for pipeline stages */ + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def getStages: Array[PipelineStage] = get(stages) + + /** + * Fits the pipeline to the input dataset with additional parameters. 
If a stage is an + * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * Then the model, which is a transformer, will be used to transform the dataset as the input to + * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an + * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the + * pipeline stages. If there are no stages, the output model acts as an identity transformer. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted pipeline + */ + override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val theStages = map(stages) + // Search for the last estimator. + var indexOfLastEstimator = -1 + theStages.view.zipWithIndex.foreach { case (stage, index) => + stage match { + case _: Estimator[_] => + indexOfLastEstimator = index + case _ => + } + } + var curDataset = dataset + val transformers = ListBuffer.empty[Transformer] + theStages.view.zipWithIndex.foreach { case (stage, index) => + if (index <= indexOfLastEstimator) { + val transformer = stage match { + case estimator: Estimator[_] => + estimator.fit(curDataset, paramMap) + case t: Transformer => + t + case _ => + throw new IllegalArgumentException( + s"Do not support stage $stage of type ${stage.getClass}") + } + curDataset = transformer.transform(curDataset, paramMap) + transformers += transformer + } else { + transformers += stage.asInstanceOf[Transformer] + } + } + + new PipelineModel(this, map, transformers.toArray) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val theStages = map(stages) + require(theStages.toSet.size == theStages.size, + "Cannot have duplicate components in a pipeline.") + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + } +} + +/** + * :: AlphaComponent :: + * Represents a compiled pipeline. + */ +@AlphaComponent +class PipelineModel private[ml] ( + override val parent: Pipeline, + override val fittingParamMap: ParamMap, + private[ml] val stages: Array[Transformer]) + extends Model[PipelineModel] with Logging { + + /** + * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input + * estimator does not exist in the pipeline. 
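+   * A hypothetical lookup, assuming `pipelineModel` is a fitted [[PipelineModel]] and `lr` is the
+   * estimator instance that was placed in the parent pipeline's stages:
+   * {{{
+   *   val lrModel = pipelineModel.getModel(lr)
+   * }}}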
+ */ + def getModel[M <: Model[M]](stage: Estimator[M]): M = { + val matched = stages.filter { + case m: Model[_] => m.parent.eq(stage) + case _ => false + } + if (matched.isEmpty) { + throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") + } else if (matched.size > 1) { + throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") + } else { + matched.head.asInstanceOf[M] + } + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala new file mode 100644 index 0000000000000000000000000000000000000000..490e6609ad311011196e8c2e4685adaeb635b7fc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.types._ + +/** + * :: AlphaComponent :: + * Abstract class for transformers that transform one dataset into another. + */ +@AlphaComponent +abstract class Transformer extends PipelineStage with Params { + + /** + * Transforms the dataset with optional parameters + * @param dataset input dataset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + val map = new ParamMap() + paramPairs.foreach(map.put(_)) + transform(dataset, map) + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + + // Java-friendly versions of transform. + + /** + * Transforms the dataset with optional parameters. 
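+   * A hypothetical Java call, assuming `hashingTF` is a configured
+   * [[org.apache.spark.ml.feature.HashingTF]] and `data` has a matching input column:
+   * {{{
+   *   JavaSchemaRDD transformed = hashingTF.transform(data, hashingTF.numFeatures().w(1 << 20));
+   * }}}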
+ * @param dataset input datset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD + } +} + +/** + * Abstract class for transformers that take one input column, apply transformation, and output the + * result as a new column. + */ +private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] + extends Transformer with HasInputCol with HasOutputCol with Logging { + + def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] + + /** + * Creates the transform function using the given param map. The input param map already takes + * account of the embedded param map. So the param values should be determined solely by the input + * param map. + */ + protected def createTransformFunc(paramMap: ParamMap): IN => OUT + + /** + * Validates the input type. Throw an exception if it is invalid. + */ + protected def validateInputType(inputType: DataType): Unit = {} + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + validateInputType(inputType) + if (schema.fieldNames.contains(map(outputCol))) { + throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + } + val output = ScalaReflection.schemaFor[OUT] + val outputFields = schema.fields :+ + StructField(map(outputCol), output.dataType, output.nullable) + StructType(outputFields) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val udf = this.createTransformFunc(map) + dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala new file mode 100644 index 0000000000000000000000000000000000000000..85b8899636ca552f69868ffea67a4accd8b02eab --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.storage.StorageLevel + +/** + * :: AlphaComponent :: + * Params for logistic regression. + */ +@AlphaComponent +private[classification] trait LogisticRegressionParams extends Params + with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol + with HasScoreCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean): StructType = { + val map = this.paramMap ++ paramMap + val featuresType = schema(map(featuresCol)).dataType + // TODO: Support casting Array[Double] and Array[Float] to Vector. + require(featuresType.isInstanceOf[VectorUDT], + s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") + if (fitting) { + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") + } + val fieldNames = schema.fieldNames + require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") + require(!fieldNames.contains(map(predictionCol)), + s"Prediction column ${map(predictionCol)} already exists.") + val outputFields = schema.fields ++ Seq( + StructField(map(scoreCol), DoubleType, false), + StructField(map(predictionCol), DoubleType, false)) + StructType(outputFields) + } +} + +/** + * Logistic regression. 
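+ * A hypothetical usage sketch; "label" and "features" are the default column names from the shared
+ * params, and `training` stands for a SchemaRDD that has such columns:
+ * {{{
+ *   val lr = new LogisticRegression()
+ *     .setMaxIter(50)
+ *     .setRegParam(0.1)
+ *   val model = lr.fit(training)
+ * }}}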
+ */ +class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + setThreshold(0.5) + + def setRegParam(value: Double): this.type = set(regParam, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + }.persist(StorageLevel.MEMORY_AND_DISK) + val lr = new LogisticRegressionWithLBFGS + lr.optimizer + .setRegParam(map(regParam)) + .setNumIterations(map(maxIter)) + val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) + instances.unpersist() + // copy model params + Params.inheritValues(map, this, lrm) + lrm + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true) + } +} + +/** + * :: AlphaComponent :: + * Model produced by [[LogisticRegression]]. + */ +@AlphaComponent +class LogisticRegressionModel private[ml] ( + override val parent: LogisticRegression, + override val fittingParamMap: ParamMap, + weights: Vector) + extends Model[LogisticRegressionModel] with LogisticRegressionParams { + + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val score: Vector => Double = (v) => { + val margin = BLAS.dot(v, weights) + 1.0 / (1.0 + math.exp(-margin)) + } + val t = map(threshold) + val predict: Double => Double = (score) => { + if (score > t) 1.0 else 0.0 + } + dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) + .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..0b0504e036ec98b542c42119bad9057c114d1209 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} + +/** + * :: AlphaComponent :: + * Evaluator for binary classification, which expects two input columns: score and label. + */ +@AlphaComponent +class BinaryClassificationEvaluator extends Evaluator with Params + with HasScoreCol with HasLabelCol { + + /** param for metric name in evaluation */ + val metricName: Param[String] = new Param(this, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + def getMetricName: String = get(metricName) + def setMetricName(value: String): this.type = set(metricName, value) + + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + val map = this.paramMap ++ paramMap + + val schema = dataset.schema + val scoreType = schema(map(scoreCol)).dataType + require(scoreType == DoubleType, + s"Score column ${map(scoreCol)} must be double type but found $scoreType") + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Label column ${map(labelCol)} must be double type but found $labelType") + + import dataset.sqlContext._ + val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + .map { case Row(score: Double, label: Double) => + (score, label) + } + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val metric = map(metricName) match { + case "areaUnderROC" => + metrics.areaUnderROC() + case "areaUnderPR" => + metrics.areaUnderPR() + case other => + throw new IllegalArgumentException(s"Does not support metric $other.") + } + metrics.unpersist() + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala new file mode 100644 index 0000000000000000000000000000000000000000..b98b1755a3584b2437b25cea8de010d7d3878fc5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.Vector + +/** + * :: AlphaComponent :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + */ +@AlphaComponent +class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + + /** number of features */ + val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + def setNumFeatures(value: Int) = set(numFeatures, value) + def getNumFeatures: Int = get(numFeatures) + + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { + val hashingTF = new feature.HashingTF(paramMap(numFeatures)) + hashingTF.transform + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala new file mode 100644 index 0000000000000000000000000000000000000000..896a6b83b67bf8c63663dedc8c81f9c7b997c8dc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ + +/** + * Params for [[StandardScaler]] and [[StandardScalerModel]]. + */ +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol + +/** + * :: AlphaComponent :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. 
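+ * A hypothetical usage sketch; the column names are illustrative and the input column is expected
+ * to contain [[org.apache.spark.mllib.linalg.Vector]] values:
+ * {{{
+ *   val scaler = new StandardScaler()
+ *     .setInputCol("features")
+ *     .setOutputCol("scaledFeatures")
+ *   val scalerModel = scaler.fit(dataset)
+ *   val scaled = scalerModel.transform(dataset)
+ * }}}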
+ */ +@AlphaComponent +class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val input = dataset.select(map(inputCol).attr) + .map { case Row(v: Vector) => + v + } + val scaler = new feature.StandardScaler().fit(input) + val model = new StandardScalerModel(this, map, scaler) + Params.inheritValues(map, this, model) + model + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StandardScaler]]. + */ +@AlphaComponent +class StandardScalerModel private[ml] ( + override val parent: StandardScaler, + override val fittingParamMap: ParamMap, + scaler: feature.StandardScalerModel) + extends Model[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val scale: (Vector) => Vector = (v) => { + scaler.transform(v) + } + dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala new file mode 100644 index 0000000000000000000000000000000000000000..0a6599b64c01172b8f301b08345ea3054729fc0d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.{DataType, StringType} + +/** + * :: AlphaComponent :: + * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + */ +@AlphaComponent +class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { + + protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + _.toLowerCase.split("\\s") + } + + protected override def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..00d9c802e930d66b49171260f7c852c14545b96c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +@AlphaComponent +package org.apache.spark.ml; + +import org.apache.spark.annotation.AlphaComponent; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala new file mode 100644 index 0000000000000000000000000000000000000000..51cd48c90432a8dab8f6748383b5d62b992da265 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala new file mode 100644 index 0000000000000000000000000000000000000000..8fd46aef4b99d3a27d61c39f213281dbe424ffd2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import java.lang.reflect.Modifier + +import org.apache.spark.annotation.AlphaComponent + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.ml.Identifiable + +/** + * :: AlphaComponent :: + * A param with self-contained documentation and optionally default value. Primitive-typed param + * should use the specialized versions, which are more friendly to Java users. + * + * @param parent parent object + * @param name param name + * @param doc documentation + * @tparam T param value type + */ +@AlphaComponent +class Param[T] ( + val parent: Params, + val name: String, + val doc: String, + val defaultValue: Option[T] = None) + extends Serializable { + + /** + * Creates a param pair with the given value (for Java). + */ + def w(value: T): ParamPair[T] = this -> value + + /** + * Creates a param pair with the given value (for Scala). + */ + def ->(value: T): ParamPair[T] = ParamPair(this, value) + + override def toString: String = { + if (defaultValue.isDefined) { + s"$name: $doc (default: ${defaultValue.get})" + } else { + s"$name: $doc" + } + } +} + +// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... + +/** Specialized version of [[Param[Double]]] for Java. */ +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) + extends Param[Double](parent, name, doc, defaultValue) { + + override def w(value: Double): ParamPair[Double] = super.w(value) +} + +/** Specialized version of [[Param[Int]]] for Java. 
*/ +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) + extends Param[Int](parent, name, doc, defaultValue) { + + override def w(value: Int): ParamPair[Int] = super.w(value) +} + +/** Specialized version of [[Param[Float]]] for Java. */ +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) + extends Param[Float](parent, name, doc, defaultValue) { + + override def w(value: Float): ParamPair[Float] = super.w(value) +} + +/** Specialized version of [[Param[Long]]] for Java. */ +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) + extends Param[Long](parent, name, doc, defaultValue) { + + override def w(value: Long): ParamPair[Long] = super.w(value) +} + +/** Specialized version of [[Param[Boolean]]] for Java. */ +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) + extends Param[Boolean](parent, name, doc, defaultValue) { + + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) +} + +/** + * A param amd its value. + */ +case class ParamPair[T](param: Param[T], value: T) + +/** + * :: AlphaComponent :: + * Trait for components that take parameters. This also provides an internal param map to store + * parameter values attached to the instance. + */ +@AlphaComponent +trait Params extends Identifiable with Serializable { + + /** Returns all params. */ + def params: Array[Param[_]] = { + val methods = this.getClass.getMethods + methods.filter { m => + Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty + }.sortBy(_.getName) + .map(m => m.invoke(this).asInstanceOf[Param[_]]) + } + + /** + * Validates parameter values stored internally plus the input parameter map. + * Raises an exception if any parameter is invalid. + */ + def validate(paramMap: ParamMap): Unit = {} + + /** + * Validates parameter values stored internally. + * Raise an exception if any parameter value is invalid. + */ + def validate(): Unit = validate(ParamMap.empty) + + /** + * Returns the documentation of all params. + */ + def explainParams(): String = params.mkString("\n") + + /** Checks whether a param is explicitly set. */ + def isSet(param: Param[_]): Boolean = { + require(param.parent.eq(this)) + paramMap.contains(param) + } + + /** Gets a param by its name. */ + private[ml] def getParam(paramName: String): Param[Any] = { + val m = this.getClass.getMethod(paramName) + assert(Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty) + m.invoke(this).asInstanceOf[Param[Any]] + } + + /** + * Sets a parameter in the embedded param map. + */ + private[ml] def set[T](param: Param[T], value: T): this.type = { + require(param.parent.eq(this)) + paramMap.put(param.asInstanceOf[Param[Any]], value) + this + } + + /** + * Gets the value of a parameter in the embedded param map. + */ + private[ml] def get[T](param: Param[T]): T = { + require(param.parent.eq(this)) + paramMap(param) + } + + /** + * Internal param map. + */ + protected val paramMap: ParamMap = ParamMap.empty +} + +private[ml] object Params { + + /** + * Copies parameter values from the parent estimator to the child model it produced. 
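+   * This is called by estimators after fitting (see, e.g., the fit method of
+   * [[org.apache.spark.ml.classification.LogisticRegression]]) so that parameter values such as
+   * output column names set on the estimator also apply to the produced model.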
+ * @param paramMap the param map that holds parameters of the parent + * @param parent the parent estimator + * @param child the child model + */ + def inheritValues[E <: Params, M <: E]( + paramMap: ParamMap, + parent: E, + child: M): Unit = { + parent.params.foreach { param => + if (paramMap.contains(param)) { + child.set(child.getParam(param.name), paramMap(param)) + } + } + } +} + +/** + * :: AlphaComponent :: + * A param to value map. + */ +@AlphaComponent +class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { + + /** + * Creates an empty param map. + */ + def this() = this(mutable.Map.empty[Param[Any], Any]) + + /** + * Puts a (param, value) pair (overwrites if the input param exists). + */ + def put[T](param: Param[T], value: T): this.type = { + map(param.asInstanceOf[Param[Any]]) = value + this + } + + /** + * Puts a list of param pairs (overwrites if the input params exists). + */ + def put(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + put(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Optionally returns the value associated with a param or its default. + */ + def get[T](param: Param[T]): Option[T] = { + map.get(param.asInstanceOf[Param[Any]]) + .orElse(param.defaultValue) + .asInstanceOf[Option[T]] + } + + /** + * Gets the value of the input param or its default value if it does not exist. + * Raises a NoSuchElementException if there is no value associated with the input param. + */ + def apply[T](param: Param[T]): T = { + val value = get(param) + if (value.isDefined) { + value.get + } else { + throw new NoSuchElementException(s"Cannot find param ${param.name}.") + } + } + + /** + * Checks whether a parameter is explicitly specified. + */ + def contains(param: Param[_]): Boolean = { + map.contains(param.asInstanceOf[Param[Any]]) + } + + /** + * Filters this param map for the given parent. + */ + def filter(parent: Params): ParamMap = { + val filtered = map.filterKeys(_.parent == parent) + new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + } + + /** + * Make a copy of this param map. + */ + def copy: ParamMap = new ParamMap(map.clone()) + + override def toString: String = { + map.map { case (param, value) => + s"\t${param.parent.uid}-${param.name}: $value" + }.mkString("{\n", ",\n", "\n}") + } + + /** + * Returns a new param map that contains parameters in this map and the given map, + * where the latter overwrites this if there exists conflicts. + */ + def ++(other: ParamMap): ParamMap = { + new ParamMap(this.map ++ other.map) + } + + + /** + * Adds all parameters from the input param map into this param map. + */ + def ++=(other: ParamMap): this.type = { + this.map ++= other.map + this + } + + /** + * Converts this param map to a sequence of param pairs. + */ + def toSeq: Seq[ParamPair[_]] = { + map.toSeq.map { case (param, value) => + ParamPair(param, value) + } + } +} + +object ParamMap { + + /** + * Returns an empty param map. + */ + def empty: ParamMap = new ParamMap() + + /** + * Constructs a param map by specifying its entries. 
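+   * A hypothetical example, assuming `lr` is a
+   * [[org.apache.spark.ml.classification.LogisticRegression]] instance:
+   * {{{
+   *   val map = ParamMap(lr.maxIter -> 20, lr.regParam -> 0.01)
+   * }}}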
+ */ + @varargs + def apply(paramPairs: ParamPair[_]*): ParamMap = { + new ParamMap().put(paramPairs: _*) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala new file mode 100644 index 0000000000000000000000000000000000000000..ef141d3eb2b0686a3379bd3a56e73899978c4582 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +private[ml] trait HasRegParam extends Params { + /** param for regularization parameter */ + val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + def getRegParam: Double = get(regParam) +} + +private[ml] trait HasMaxIter extends Params { + /** param for max number of iterations */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +private[ml] trait HasFeaturesCol extends Params { + /** param for features column name */ + val featuresCol: Param[String] = + new Param(this, "featuresCol", "features column name", Some("features")) + def getFeaturesCol: String = get(featuresCol) +} + +private[ml] trait HasLabelCol extends Params { + /** param for label column name */ + val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) + def getLabelCol: String = get(labelCol) +} + +private[ml] trait HasScoreCol extends Params { + /** param for score column name */ + val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) + def getScoreCol: String = get(scoreCol) +} + +private[ml] trait HasPredictionCol extends Params { + /** param for prediction column name */ + val predictionCol: Param[String] = + new Param(this, "predictionCol", "prediction column name", Some("prediction")) + def getPredictionCol: String = get(predictionCol) +} + +private[ml] trait HasThreshold extends Params { + /** param for threshold in (binary) prediction */ + val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + def getThreshold: Double = get(threshold) +} + +private[ml] trait HasInputCol extends Params { + /** param for input column name */ + val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + def getInputCol: String = get(inputCol) +} + +private[ml] trait HasOutputCol extends Params { + /** param for output column name */ + val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + def getOutputCol: String = get(outputCol) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala new file mode 100644 index 0000000000000000000000000000000000000000..194b9bfd9a9e64c779d4e6c86bca9c5cafc1e154 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import com.github.fommil.netlib.F2jBLAS + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * Params for [[CrossValidator]] and [[CrossValidatorModel]]. + */ +private[ml] trait CrossValidatorParams extends Params { + /** param for the estimator to be cross-validated */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + def getEstimator: Estimator[_] = get(estimator) + + /** param for estimator param maps */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + + /** param for the evaluator for selection */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + def getEvaluator: Evaluator = get(evaluator) + + /** param for number of folds for cross validation */ + val numFolds: IntParam = + new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + def getNumFolds: Int = get(numFolds) +} + +/** + * :: AlphaComponent :: + * K-fold cross validation. 
+ */ +@AlphaComponent +class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { + + private val f2jBLAS = new F2jBLAS + + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + def setNumFolds(value: Int): this.type = set(numFolds, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + val map = this.paramMap ++ paramMap + val schema = dataset.schema + transformSchema(dataset.schema, paramMap, logging = true) + val sqlCtx = dataset.sqlContext + val est = map(estimator) + val eval = map(evaluator) + val epm = map(estimatorParamMaps) + val numModels = epm.size + val metrics = new Array[Double](epm.size) + val splits = MLUtils.kFold(dataset, map(numFolds), 0) + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => + val trainingDataset = sqlCtx.applySchema(training, schema).cache() + val validationDataset = sqlCtx.applySchema(validation, schema).cache() + // multi-model training + logDebug(s"Train split $splitIndex with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + var i = 0 + while (i < numModels) { + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + } + f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best cross-validation metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + val cvModel = new CrossValidatorModel(this, map, bestModel) + Params.inheritValues(map, this, cvModel) + cvModel + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + map(estimator).transformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model from k-fold cross validation. + */ +@AlphaComponent +class CrossValidatorModel private[ml] ( + override val parent: CrossValidator, + override val fittingParamMap: ParamMap, + val bestModel: Model[_]) + extends Model[CrossValidatorModel] with CrossValidatorParams { + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + bestModel.transform(dataset, paramMap) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + bestModel.transformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala new file mode 100644 index 0000000000000000000000000000000000000000..dafe73d82c00a3dc8a76cc748651c7486288aaa0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ + +/** + * :: AlphaComponent :: + * Builder for a param grid used in grid search-based model selection. + */ +@AlphaComponent +class ParamGridBuilder { + + private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] + + /** + * Sets the given parameters in this grid to fixed values. + */ + def baseOn(paramMap: ParamMap): this.type = { + baseOn(paramMap.toSeq: _*) + this + } + + /** + * Sets the given parameters in this grid to fixed values. + */ + @varargs + def baseOn(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value)) + } + this + } + + /** + * Adds a param with multiple values (overwrites if the input param exists). + */ + def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { + paramGrid.put(param, values) + this + } + + // specialized versions of addGrid for Java. + + /** + * Adds a double param with multiple values. + */ + def addGrid(param: DoubleParam, values: Array[Double]): this.type = { + addGrid[Double](param, values) + } + + /** + * Adds an int param with multiple values. + */ + def addGrid(param: IntParam, values: Array[Int]): this.type = { + addGrid[Int](param, values) + } + + /** + * Adds a float param with multiple values. + */ + def addGrid(param: FloatParam, values: Array[Float]): this.type = { + addGrid[Float](param, values) + } + + /** + * Adds a long param with multiple values. + */ + def addGrid(param: LongParam, values: Array[Long]): this.type = { + addGrid[Long](param, values) + } + + /** + * Adds a boolean param with true and false. + */ + def addGrid(param: BooleanParam): this.type = { + addGrid[Boolean](param, Array(true, false)) + } + + /** + * Builds and returns all combinations of parameters specified by the param grid. + */ + def build(): Array[ParamMap] = { + var paramMaps = Array(new ParamMap) + paramGrid.foreach { case (param, values) => + val newParamMaps = values.flatMap { v => + paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v)) + } + paramMaps = newParamMaps.toArray + } + paramMaps + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 54ee930d610031bc2cdf36501c0aed34cd9700ae..89539e600f48c41b02d43a9302c307b2e1f6906d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices.
*/ -private[mllib] object BLAS extends Serializable with Logging { +private[spark] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index ac217edc619ab6ee278a9a13ec4f1f84b1b4bfee..9fccd6341ba7da127e9d80e853af10863d274370 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -115,6 +115,9 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { + // TODO: something wrong with UDT serialization + case v: Vector => + v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 17c753c56681fbbd155fbb739a5c87f96989eb58..2067b36f246b3e33ef6d476c9a882af8ee48b99a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.regression +import scala.beans.BeanInfo + import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -27,6 +29,7 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. */ +@BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { "(%s,%s)".format(label, features) diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..42846677ed2850695b1dee2442ef8a0c5b1e0646 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +/** + * Test Pipeline construction and fitting in Java. + */ +public class JavaPipelineSuite { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPipelineSuite"); + jsql = new JavaSQLContext(jsc); + JavaRDD<LabeledPoint> points = + jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); + dataset = jsql.applySchema(points, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void pipeline() { + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + LogisticRegression lr = new LogisticRegression() + .setFeaturesCol("scaledFeatures"); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {scaler, lr}); + PipelineModel model = pipeline.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..76eb7f00329f2bcf964497a04b703d5f4d24f7e1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaLogisticRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionWithSetters() { + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold + .registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionFitWithVarargs() { + LogisticRegression lr = new LogisticRegression(); + lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..a266ebd2071a1fd819c9af8c3ed432bd54a97f15 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaCrossValidatorSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); + jsql = new JavaSQLContext(jsc); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void crossValidationWithLogisticRegression() { + LogisticRegression lr = new LogisticRegression(); + ParamMap[] lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[] {0, 10}) + .build(); + BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(dataset); + ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); + Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); + Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4515084bc7ae95c7ca9b49c32606517966c06d1a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +class PipelineSuite extends FunSuite { + + abstract class MyModel extends Model[MyModel] + + test("pipeline") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val transformer3 = mock[Transformer] + val dataset0 = mock[SchemaRDD] + val dataset1 = mock[SchemaRDD] + val dataset2 = mock[SchemaRDD] + val dataset3 = mock[SchemaRDD] + val dataset4 = mock[SchemaRDD] + + when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) + when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) + when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages(0).eq(model0)) + assert(pipelineModel.stages(1).eq(transformer1)) + assert(pipelineModel.stages(2).eq(model2)) + assert(pipelineModel.stages(3).eq(transformer3)) + + assert(pipelineModel.getModel(estimator0).eq(model0)) + assert(pipelineModel.getModel(estimator2).eq(model2)) + intercept[NoSuchElementException] { + pipelineModel.getModel(mock[Estimator[MyModel]]) + } + val output = pipelineModel.transform(dataset0) + assert(output.eq(dataset4)) + } + + test("pipeline with duplicate stages") { + val estimator = mock[Estimator[MyModel]] + val pipeline = new Pipeline() + .setStages(Array(estimator, estimator)) + val dataset = mock[SchemaRDD] + intercept[IllegalArgumentException] { + pipeline.fit(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..625af299a540370a7262a94acfc4529c673573e3 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.sql.SchemaRDD + +class LogisticRegressionSuite extends FunSuite with LocalSparkContext { + + import sqlContext._ + + val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2) + + test("logistic regression") { + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + } + + test("logistic regression with setters") { + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + model.transform(dataset, model.threshold -> 0.8) // overwrite threshold + .select('label, 'score, 'prediction) + .collect() + } + + test("logistic regression fit and transform with varargs") { + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .collect() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..1ce298761237848a74f6f43fc145cfe8eefabdfb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.param + +import org.scalatest.FunSuite + +class ParamsSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param") { + assert(maxIter.name === "maxIter") + assert(maxIter.doc === "max number of iterations") + assert(maxIter.defaultValue.get === 100) + assert(maxIter.parent.eq(solver)) + assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") + assert(inputCol.defaultValue === None) + } + + test("param pair") { + val pair0 = maxIter -> 5 + val pair1 = maxIter.w(5) + val pair2 = ParamPair(maxIter, 5) + for (pair <- Seq(pair0, pair1, pair2)) { + assert(pair.param.eq(maxIter)) + assert(pair.value === 5) + } + } + + test("param map") { + val map0 = ParamMap.empty + + assert(!map0.contains(maxIter)) + assert(map0(maxIter) === maxIter.defaultValue.get) + map0.put(maxIter, 10) + assert(map0.contains(maxIter)) + assert(map0(maxIter) === 10) + + assert(!map0.contains(inputCol)) + intercept[NoSuchElementException] { + map0(inputCol) + } + map0.put(inputCol -> "input") + assert(map0.contains(inputCol)) + assert(map0(inputCol) === "input") + + val map1 = map0.copy + val map2 = ParamMap(maxIter -> 10, inputCol -> "input") + val map3 = new ParamMap() + .put(maxIter, 10) + .put(inputCol, "input") + val map4 = ParamMap.empty ++ map0 + val map5 = ParamMap.empty + map5 ++= map0 + + for (m <- Seq(map1, map2, map3, map4, map5)) { + assert(m.contains(maxIter)) + assert(m(maxIter) === 10) + assert(m.contains(inputCol)) + assert(m(inputCol) === "input") + } + } + + test("params") { + val params = solver.params + assert(params.size === 2) + assert(params(0).eq(inputCol), "params must be ordered by name") + assert(params(1).eq(maxIter)) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) + assert(solver.getParam("maxIter").eq(maxIter)) + intercept[NoSuchMethodException] { + solver.getParam("abc") + } + assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { + solver.validate() + } + solver.validate(ParamMap(inputCol -> "input")) + solver.setInputCol("input") + assert(solver.isSet(inputCol)) + assert(solver.getInputCol === "input") + solver.validate() + intercept[IllegalArgumentException] { + solver.validate(ParamMap(maxIter -> -10)) + } + solver.setMaxIter(-10) + intercept[IllegalArgumentException] { + solver.validate() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala new file mode 100644 index 0000000000000000000000000000000000000000..1a65883d78a71b324aac1b22aa0f8b82c980d3c8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +/** A subclass of Params for testing. */ +class TestParams extends Params { + + val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def getMaxIter: Int = get(maxIter) + + val inputCol = new Param[String](this, "inputCol", "input column name") + def setInputCol(value: String): this.type = { set(inputCol, value); this } + def getInputCol: String = get(inputCol) + + override def validate(paramMap: ParamMap) = { + val m = this.paramMap ++ paramMap + require(m(maxIter) >= 0) + require(m.contains(inputCol)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..72a334ae9303ec7e19b415d7b27827c411c92f8e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.scalatest.FunSuite + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.sql.SchemaRDD + +class CrossValidatorSuite extends FunSuite with LocalSparkContext { + + import sqlContext._ + + val dataset: SchemaRDD = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2) + + test("cross validation with logistic regression") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val bestParamMap = cvModel.bestModel.fittingParamMap + assert(bestParamMap(lr.regParam) === 0.001) + assert(bestParamMap(lr.maxIter) === 10) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..20aa100112bfe0f6b2f1b2bed0a4e110330b485d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.ml.param.{ParamMap, TestParams} + +class ParamGridBuilderSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param grid builder") { + def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = { + assert(maps.size === expected.size) + maps.foreach { m => + val tuple = (m(maxIter), m(inputCol)) + assert(expected.contains(tuple)) + expected.remove(tuple) + } + assert(expected.isEmpty) + } + + val maps0 = new ParamGridBuilder() + .baseOn(maxIter -> 10) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected0 = mutable.Set( + (10, "input0"), + (10, "input1")) + validateGrid(maps0, expected0) + + val maps1 = new ParamGridBuilder() + .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten + .addGrid(maxIter, Array(10, 20)) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected1 = mutable.Set( + (10, "input0"), + (20, "input0"), + (10, "input1"), + (20, "input1")) + validateGrid(maps1, expected1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala index 7857d9e5ee5c4920e8c4cb1b9228ce8a91c00744..4417d66adf0fc2d90c1c55e42d3f2e8ac6fe3dc5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala @@ -17,26 +17,17 @@ package org.apache.spark.mllib.util -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => - @transient var sc: SparkContext = _ - - override def beforeAll() { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - sc = new SparkContext(conf) - super.beforeAll() - } + @transient val sc = new SparkContext("local", "test") + @transient lazy val sqlContext = new SQLContext(sc) override def afterAll() { - if (sc != null) { - sc.stop() - } + sc.stop() super.afterAll() } }
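Note (illustrative only, not part of the patch): a minimal sketch of how the pieces introduced above fit together, mirroring CrossValidatorSuite. The `training` SchemaRDD and the specific grid values are assumptions made for the sake of the example.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// `training` is assumed to be a SchemaRDD of labeled points prepared elsewhere.
val lr = new LogisticRegression
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1, 1.0))  // DoubleParam grid
  .addGrid(lr.maxIter, Array(10, 20))           // IntParam grid
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setNumFolds(3)
val cvModel = cv.fit(training)
// The parameters chosen by cross validation are recorded in the fitted model.
println(cvModel.bestModel.fittingParamMap)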