diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index 0f6d5809e098f1997f5733ab25b0ca205ee3fa4e..c53475818395f19260c9c21c765a4e9fe8fb7dbc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -32,12 +32,12 @@ import org.apache.spark.util.Utils * :: Experimental :: * Maps a sequence of terms to their term frequencies using the hashing trick. * - * @param numFeatures number of features (default: 1000000) + * @param numFeatures number of features (default: 2^20^) */ @Experimental class HashingTF(val numFeatures: Int) extends Serializable { - def this() = this(1000000) + def this() = this(1 << 20) /** * Returns the index of the input term. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index ea9fd0a80d8e0d92e32ce655a1a856845650bd5c..3afb47767281ccb400deabeece81907fd276eb69 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * :: DeveloperApi :: + * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * * For any 1 <= p < Double.PositiveInfinity, normalizes samples using @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} * * @param p Normalization in L^p^ space, p = 2 by default. */ -@DeveloperApi +@Experimental class Normalizer(p: Double) extends VectorTransformer { def this() = this(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index cc2d7579c2901213aedf2e238eaeb680bd28a1c9..e6c9f8f67df6347e26896eaa3b5c2cf6f7d8d67b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -19,14 +19,14 @@ package org.apache.spark.mllib.feature import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** - * :: DeveloperApi :: + * :: Experimental :: * Standardizes features by removing the mean and scaling to unit variance using column summary * statistics on the samples in the training set. * @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD * dense output, so this does not work on sparse input and will raise an exception. * @param withStd True by default. Scales the data to unit standard deviation. */ -@DeveloperApi +@Experimental class StandardScaler(withMean: Boolean, withStd: Boolean) extends VectorTransformer { def this() = this(false, true) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 3bf44ad7c44e36af30c9330258851962ccffb4a8..395037e1ec47c6bd4746ac3bb72bdcbaa981ef0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -17,6 +17,9 @@ package org.apache.spark.mllib.feature +import java.lang.{Iterable => JavaIterable} + +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -25,6 +28,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd._ @@ -239,7 +243,7 @@ class Word2Vec extends Serializable with Logging { a += 1 } } - + /** * Computes the vector representation of each word in vocabulary. * @param dataset an RDD of words @@ -369,11 +373,22 @@ class Word2Vec extends Serializable with Logging { new Word2VecModel(word2VecMap.toMap) } + + /** + * Computes the vector representation of each word in vocabulary (Java version). + * @param dataset a JavaRDD of words + * @return a Word2VecModel + */ + def fit[S <: JavaIterable[String]](dataset: JavaRDD[S]): Word2VecModel = { + fit(dataset.rdd.map(_.asScala)) + } } /** -* Word2Vec model + * :: Experimental :: + * Word2Vec model */ +@Experimental class Word2VecModel private[mllib] ( private val model: Map[String, Array[Float]]) extends Serializable { diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..fb7afe8c6434bbdd2751a6e2fcb11f5c5f3c94c8 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaWord2VecSuite.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple2; + +import com.google.common.collect.Lists; +import com.google.common.base.Strings; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaWord2VecSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaWord2VecSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + @SuppressWarnings("unchecked") + public void word2Vec() { + // The tests are to check Java compatibility. + String sentence = Strings.repeat("a b ", 100) + Strings.repeat("a c ", 10); + List<String> words = Lists.newArrayList(sentence.split(" ")); + List<List<String>> localDoc = Lists.newArrayList(words, words); + JavaRDD<List<String>> doc = sc.parallelize(localDoc); + Word2Vec word2vec = new Word2Vec() + .setVectorSize(10) + .setSeed(42L); + Word2VecModel model = word2vec.fit(doc); + Tuple2<String, Object>[] syms = model.findSynonyms("a", 2); + Assert.assertEquals(2, syms.length); + Assert.assertEquals("b", syms[0]._1()); + Assert.assertEquals("c", syms[1]._1()); + } +}