Commit d91967e1 authored by Holden Karau, committed by Xiangrui Meng

[SPARK-10763] [ML] [JAVA] [TEST] Update Java MLLIB/ML tests to use simplified dataframe construction

As introduced in https://issues.apache.org/jira/browse/SPARK-10630, we now have an easier way to create DataFrames from local Java lists. Let's update the tests to use it.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8886 from holdenk/SPARK-10763-update-java-mllib-ml-tests-to-use-simplified-dataframe-construction.
parent 758c9d25
Showing 42 additions and 39 deletions
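
For context, here is a minimal sketch of the pattern this commit adopts, using rows borrowed from the HashingTF test below; jsql is assumed to be an existing SQLContext, as in these suites:

// Before (RDD-based style): build a JavaRDD<Row> first, then convert it.
//   JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(...));
//   DataFrame dataset = jsql.createDataFrame(jrdd, schema);
// After (SPARK-10630): pass the local java.util.List<Row> directly.
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Local data as a plain Java list of rows.
List<Row> data = Arrays.asList(
  RowFactory.create(0.0, "Hi I heard about Spark"),
  RowFactory.create(1.0, "Logistic regression models are neat"));
// Schema describing the two columns above.
StructType schema = new StructType(new StructField[]{
  new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
// No JavaSparkContext.parallelize() call needed.
DataFrame dataset = jsql.createDataFrame(data, schema);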
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
@@ -75,21 +76,20 @@ public class JavaNaiveBayesSuite implements Serializable {
   @Test
   public void testNaiveBayes() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
       RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
       RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
       RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
       RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
-      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0))
-    ));
+      RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(jrdd, schema);
+    DataFrame dataset = jsql.createDataFrame(data, schema);
     NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
     NaiveBayesModel model = nb.fit(dataset);
@@ -55,16 +55,16 @@ public class JavaBucketizerSuite {
   public void bucketizerTest() {
     double[] splits = {-0.5, 0.0, 0.5};
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
-      RowFactory.create(-0.5),
-      RowFactory.create(-0.3),
-      RowFactory.create(0.0),
-      RowFactory.create(0.2)
-    ));
     StructType schema = new StructType(new StructField[] {
       new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(data, schema);
+    DataFrame dataset = jsql.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(-0.5),
+        RowFactory.create(-0.3),
+        RowFactory.create(0.0),
+        RowFactory.create(0.2)),
+      schema);
     Bucketizer bucketizer = new Bucketizer()
       .setInputCol("feature")
@@ -57,12 +57,11 @@ public class JavaDCTSuite {
   @Test
   public void javaCompatibilityTest() {
     double[] input = new double[] {1D, 2D, 3D, 4D};
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
-      RowFactory.create(Vectors.dense(input))
-    ));
-    DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
-      new StructField("vec", (new VectorUDT()), false, Metadata.empty())
-    }));
+    DataFrame dataset = jsql.createDataFrame(
+      Arrays.asList(RowFactory.create(Vectors.dense(input))),
+      new StructType(new StructField[]{
+        new StructField("vec", (new VectorUDT()), false, Metadata.empty())
+      }));
     double[] expectedResult = input.clone();
     (new DoubleDCT_1D(input.length)).forward(expectedResult, true);
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -55,17 +56,17 @@ public class JavaHashingTFSuite {
   @Test
   public void hashingTF() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(0.0, "Hi I heard about Spark"),
       RowFactory.create(0.0, "I wish Java could use case classes"),
       RowFactory.create(1.0, "Logistic regression models are neat")
-    ));
+    );
     StructType schema = new StructType(new StructField[]{
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
     });
-    DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+    DataFrame sentenceData = jsql.createDataFrame(data, schema);
     Tokenizer tokenizer = new Tokenizer()
       .setInputCol("sentence")
       .setOutputCol("words");
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -60,7 +61,7 @@ public class JavaPolynomialExpansionSuite {
       .setOutputCol("polyFeatures")
       .setDegree(3);
-    JavaRDD<Row> data = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(
         Vectors.dense(-2.0, 2.3),
         Vectors.dense(-2.0, 4.0, -8.0, 2.3, -4.6, 9.2, 5.29, -10.58, 12.17)
@@ -70,7 +71,7 @@ public class JavaPolynomialExpansionSuite {
         Vectors.dense(0.6, -1.1),
         Vectors.dense(0.6, 0.36, 0.216, -1.1, -0.66, -0.396, 1.21, 0.726, -1.331)
       )
-    ));
+    );
     StructType schema = new StructType(new StructField[] {
       new StructField("features", new VectorUDT(), false, Metadata.empty()),
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Before;
@@ -58,14 +59,14 @@ public class JavaStopWordsRemoverSuite {
       .setInputCol("raw")
       .setOutputCol("filtered");
-    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
       RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")),
       RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb"))
-    ));
+    );
     StructType schema = new StructType(new StructField[] {
       new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
     });
-    DataFrame dataset = jsql.createDataFrame(rdd, schema);
+    DataFrame dataset = jsql.createDataFrame(data, schema);
     remover.transform(dataset).collect();
   }
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -56,9 +57,9 @@ public class JavaStringIndexerSuite {
       createStructField("id", IntegerType, false),
       createStructField("label", StringType, false)
     });
-    JavaRDD<Row> rdd = jsc.parallelize(
-      Arrays.asList(c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c")));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+    List<Row> data = Arrays.asList(
+      c(0, "a"), c(1, "b"), c(2, "c"), c(3, "a"), c(4, "a"), c(5, "c"));
+    DataFrame dataset = sqlContext.createDataFrame(data, schema);
     StringIndexer indexer = new StringIndexer()
       .setInputCol("label")
@@ -65,8 +65,7 @@ public class JavaVectorAssemblerSuite {
     Row row = RowFactory.create(
       0, 0.0, Vectors.dense(1.0, 2.0), "a",
       Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
-    JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(row));
-    DataFrame dataset = sqlContext.createDataFrame(rdd, schema);
+    DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
     VectorAssembler assembler = new VectorAssembler()
       .setInputCols(new String[] {"x", "y", "z", "n"})
       .setOutputCol("features");
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.feature;
 
 import java.util.Arrays;
+import java.util.List;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -63,12 +64,12 @@ public class JavaVectorSlicerSuite {
     };
     AttributeGroup group = new AttributeGroup("userFeatures", attrs);
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
+    List<Row> data = Arrays.asList(
      RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
       RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
-    ));
+    );
 
-    DataFrame dataset = jsql.createDataFrame(jrdd, (new StructType()).add(group.toStructField()));
+    DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
     VectorSlicer vectorSlicer = new VectorSlicer()
       .setInputCol("userFeatures").setOutputCol("features");
@@ -51,15 +51,15 @@ public class JavaWord2VecSuite {
   @Test
   public void testJavaWord2Vec() {
-    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
-      RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
-      RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
-      RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))
-    ));
     StructType schema = new StructType(new StructField[]{
       new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
     });
-    DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
+    DataFrame documentDF = sqlContext.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
+        RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
+        RowFactory.create(Arrays.asList("Logistic regression models are neat".split(" ")))),
+      schema);
     Word2Vec word2Vec = new Word2Vec()
       .setInputCol("text")