Skip to content
Snippets Groups Projects
Commit 871764c6 authored by Holden Karau's avatar Holden Karau Committed by Reynold Xin
Browse files

[SPARK-10013] [ML] [JAVA] [TEST] remove java assert from java unit tests

From Jira: We should use assertTrue, etc. instead to make sure the asserts are not ignored in tests.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8607 from holdenk/SPARK-10013-remove-java-assert-from-java-unit-tests.
parent bca8c072
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ import java.lang.Math;
import java.util.List;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
......@@ -63,16 +64,16 @@ public class JavaLogisticRegressionSuite implements Serializable {
@Test
public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
assert(lr.getLabelCol().equals("label"));
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
assert(model.getThreshold() == 0.5);
assert(model.getFeaturesCol().equals("features"));
assert(model.getPredictionCol().equals("prediction"));
assert(model.getProbabilityCol().equals("probability"));
Assert.assertEquals(0.5, model.getThreshold(), eps);
Assert.assertEquals("features", model.getFeaturesCol());
Assert.assertEquals("prediction", model.getPredictionCol());
Assert.assertEquals("probability", model.getProbabilityCol());
}
@Test
......@@ -85,19 +86,19 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
LogisticRegression parent = (LogisticRegression) model.parent();
assert(parent.getMaxIter() == 10);
assert(parent.getRegParam() == 1.0);
assert(parent.getThresholds()[0] == 0.4);
assert(parent.getThresholds()[1] == 0.6);
assert(parent.getThreshold() == 0.6);
assert(model.getThreshold() == 0.6);
Assert.assertEquals(10, parent.getMaxIter());
Assert.assertEquals(1.0, parent.getRegParam(), eps);
Assert.assertEquals(0.4, parent.getThresholds()[0], eps);
Assert.assertEquals(0.6, parent.getThresholds()[1], eps);
Assert.assertEquals(0.6, parent.getThreshold(), eps);
Assert.assertEquals(0.6, model.getThreshold(), eps);
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
assert(r.getDouble(0) == 0.0);
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
......@@ -107,17 +108,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
}
assert(foundNonZero);
Assert.assertTrue(foundNonZero);
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
assert(parent2.getThreshold() == 0.4);
assert(model2.getThreshold() == 0.4);
assert(model2.getProbabilityCol().equals("theProb"));
Assert.assertEquals(5, parent2.getMaxIter());
Assert.assertEquals(0.1, parent2.getRegParam(), eps);
Assert.assertEquals(0.4, parent2.getThreshold(), eps);
Assert.assertEquals(0.4, model2.getThreshold(), eps);
Assert.assertEquals("theProb", model2.getProbabilityCol());
}
@SuppressWarnings("unchecked")
......@@ -125,18 +126,18 @@ public class JavaLogisticRegressionSuite implements Serializable {
public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
assert(model.numClasses() == 2);
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row: trans1.collect()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
assert(raw.size() == 2);
assert(prob.size() == 2);
Assert.assertEquals(raw.size(), 2);
Assert.assertEquals(prob.size(), 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps);
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
......@@ -145,7 +146,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
for (int i = 0; i < prob.size(); ++i) {
assert(probOfPred >= prob.apply(i));
Assert.assertTrue(probOfPred >= prob.apply(i));
}
}
}
......@@ -156,6 +157,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionModel model = lr.fit(dataset);
LogisticRegressionTrainingSummary summary = model.summary();
assert(summary.totalIterations() == summary.objectiveHistory().length);
Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
}
}
......@@ -23,6 +23,7 @@ import java.util.Arrays;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
......@@ -58,18 +59,18 @@ public class JavaNaiveBayesSuite implements Serializable {
for (Row r : predictionAndLabels.collect()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assert(prediction == label);
assertEquals(label, prediction, 1E-5);
}
}
@Test
public void naiveBayesDefaultParams() {
NaiveBayes nb = new NaiveBayes();
assert(nb.getLabelCol() == "label");
assert(nb.getFeaturesCol() == "features");
assert(nb.getPredictionCol() == "prediction");
assert(nb.getSmoothing() == 1.0);
assert(nb.getModelType() == "multinomial");
assertEquals("label", nb.getLabelCol());
assertEquals("features", nb.getFeaturesCol());
assertEquals("prediction", nb.getPredictionCol());
assertEquals(1.0, nb.getSmoothing(), 1E-5);
assertEquals("multinomial", nb.getModelType());
}
@Test
......
......@@ -60,7 +60,7 @@ public class JavaLinearRegressionSuite implements Serializable {
@Test
public void linearRegressionDefaultParams() {
LinearRegression lr = new LinearRegression();
assert(lr.getLabelCol().equals("label"));
assertEquals("label", lr.getLabelCol());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
......
......@@ -80,10 +80,10 @@ public class JavaMatricesSuite implements Serializable {
assertArrayEquals(sd.toArray(), s.toArray(), 0.0);
assertArrayEquals(s.toArray(), ss.toArray(), 0.0);
assertArrayEquals(s.values(), ss.values(), 0.0);
assert(s.values().length == 2);
assert(ss.values().length == 2);
assert(s.colPtrs().length == 4);
assert(ss.colPtrs().length == 4);
assertEquals(2, s.values().length);
assertEquals(2, ss.values().length);
assertEquals(4, s.colPtrs().length);
assertEquals(4, ss.colPtrs().length);
}
@Test
......@@ -137,27 +137,27 @@ public class JavaMatricesSuite implements Serializable {
Matrix deHorz2 = Matrices.horzcat(new Matrix[]{spMat1, deMat2});
Matrix deHorz3 = Matrices.horzcat(new Matrix[]{deMat1, spMat2});
assert(deHorz1.numRows() == 3);
assert(deHorz2.numRows() == 3);
assert(deHorz3.numRows() == 3);
assert(spHorz.numRows() == 3);
assert(deHorz1.numCols() == 5);
assert(deHorz2.numCols() == 5);
assert(deHorz3.numCols() == 5);
assert(spHorz.numCols() == 5);
assertEquals(3, deHorz1.numRows());
assertEquals(3, deHorz2.numRows());
assertEquals(3, deHorz3.numRows());
assertEquals(3, spHorz.numRows());
assertEquals(5, deHorz1.numCols());
assertEquals(5, deHorz2.numCols());
assertEquals(5, deHorz3.numCols());
assertEquals(5, spHorz.numCols());
Matrix spVert = Matrices.vertcat(new Matrix[]{spMat1, spMat3});
Matrix deVert1 = Matrices.vertcat(new Matrix[]{deMat1, deMat3});
Matrix deVert2 = Matrices.vertcat(new Matrix[]{spMat1, deMat3});
Matrix deVert3 = Matrices.vertcat(new Matrix[]{deMat1, spMat3});
assert(deVert1.numRows() == 5);
assert(deVert2.numRows() == 5);
assert(deVert3.numRows() == 5);
assert(spVert.numRows() == 5);
assert(deVert1.numCols() == 2);
assert(deVert2.numCols() == 2);
assert(deVert3.numCols() == 2);
assert(spVert.numCols() == 2);
assertEquals(5, deVert1.numRows());
assertEquals(5, deVert2.numRows());
assertEquals(5, deVert3.numRows());
assertEquals(5, spVert.numRows());
assertEquals(2, deVert1.numCols());
assertEquals(2, deVert2.numCols());
assertEquals(2, deVert3.numCols());
assertEquals(2, spVert.numCols());
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment