Skip to content
Snippets Groups Projects
Commit 6077e3ef authored by Felix Cheung's avatar Felix Cheung Committed by Felix Cheung
Browse files

[SPARK-21801][SPARKR][TEST] unit test randomly fail with randomforest

## What changes were proposed in this pull request?

fix the random seed to eliminate variability

## How was this patch tested?

jenkins, appveyor, lots more jenkins

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #19018 from felixcheung/rrftest.
parent 6327ea57
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,7 @@ test_that("spark.gbt", {
# label must be binary - GBTClassifier currently only supports binary classification.
iris2 <- iris[iris$Species != "virginica", ]
data <- suppressWarnings(createDataFrame(iris2))
model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification")
model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification", seed = 12)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
......@@ -94,7 +94,7 @@ test_that("spark.gbt", {
iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1)
df <- suppressWarnings(createDataFrame(iris2))
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification")
m <- spark.gbt(df, NumericSpecies ~ ., type = "classification", seed = 12)
s <- summary(m)
# test numeric prediction values
expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
......@@ -106,7 +106,7 @@ test_that("spark.gbt", {
if (windows_with_hadoop()) {
data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
source = "libsvm")
model <- spark.gbt(data, label ~ features, "classification")
model <- spark.gbt(data, label ~ features, "classification", seed = 12)
expect_equal(summary(model)$numFeatures, 692)
}
......@@ -117,10 +117,11 @@ test_that("spark.gbt", {
trainidxs <- base::sample(nrow(data), nrow(data) * 0.7)
traindf <- as.DataFrame(data[trainidxs, ])
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
model <- spark.gbt(traindf, clicked ~ ., type = "classification")
model <- spark.gbt(traindf, clicked ~ ., type = "classification", seed = 23)
predictions <- predict(model, testdf)
expect_error(collect(predictions))
model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep")
model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep",
seed = 23)
predictions <- predict(model, testdf)
expect_equal(class(collect(predictions)$clicked[1]), "character")
})
......@@ -129,7 +130,7 @@ test_that("spark.randomForest", {
# regression
data <- suppressWarnings(createDataFrame(longley))
model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
numTrees = 1)
numTrees = 1, seed = 1)
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
......@@ -177,7 +178,7 @@ test_that("spark.randomForest", {
# classification
data <- suppressWarnings(createDataFrame(iris))
model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
maxDepth = 5, maxBins = 16, seed = 123)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
......@@ -215,7 +216,7 @@ test_that("spark.randomForest", {
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
data <- suppressWarnings(createDataFrame(iris[-5]))
model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
maxDepth = 5, maxBins = 16, seed = 123)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$numTrees, 20)
......@@ -234,12 +235,12 @@ test_that("spark.randomForest", {
traindf <- as.DataFrame(data[trainidxs, ])
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
maxDepth = 10, maxBins = 10, numTrees = 10)
maxDepth = 10, maxBins = 10, numTrees = 10, seed = 123)
predictions <- predict(model, testdf)
expect_error(collect(predictions))
model <- spark.randomForest(traindf, clicked ~ ., type = "classification",
maxDepth = 10, maxBins = 10, numTrees = 10,
handleInvalid = "keep")
handleInvalid = "keep", seed = 123)
predictions <- predict(model, testdf)
expect_equal(class(collect(predictions)$clicked[1]), "character")
......@@ -247,7 +248,7 @@ test_that("spark.randomForest", {
if (windows_with_hadoop()) {
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
source = "libsvm")
model <- spark.randomForest(data, label ~ features, "classification")
model <- spark.randomForest(data, label ~ features, "classification", seed = 123)
expect_equal(summary(model)$numFeatures, 4)
}
})
......@@ -255,7 +256,8 @@ test_that("spark.randomForest", {
test_that("spark.decisionTree", {
# regression
data <- suppressWarnings(createDataFrame(longley))
model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16)
model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
seed = 42)
predictions <- collect(predict(model, data))
expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187,
......@@ -288,7 +290,7 @@ test_that("spark.decisionTree", {
# classification
data <- suppressWarnings(createDataFrame(iris))
model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
maxDepth = 5, maxBins = 16, seed = 43)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
......@@ -325,7 +327,7 @@ test_that("spark.decisionTree", {
iris$NumericSpecies <- lapply(iris$Species, labelToIndex)
data <- suppressWarnings(createDataFrame(iris[-5]))
model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification",
maxDepth = 5, maxBins = 16)
maxDepth = 5, maxBins = 16, seed = 44)
stats <- summary(model)
expect_equal(stats$numFeatures, 2)
expect_equal(stats$maxDepth, 5)
......@@ -339,7 +341,7 @@ test_that("spark.decisionTree", {
if (windows_with_hadoop()) {
data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
source = "libsvm")
model <- spark.decisionTree(data, label ~ features, "classification")
model <- spark.decisionTree(data, label ~ features, "classification", seed = 45)
expect_equal(summary(model)$numFeatures, 4)
}
......@@ -351,11 +353,11 @@ test_that("spark.decisionTree", {
traindf <- as.DataFrame(data[trainidxs, ])
testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other")))
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
maxDepth = 5, maxBins = 16)
maxDepth = 5, maxBins = 16, seed = 46)
predictions <- predict(model, testdf)
expect_error(collect(predictions))
model <- spark.decisionTree(traindf, clicked ~ ., type = "classification",
maxDepth = 5, maxBins = 16, handleInvalid = "keep")
maxDepth = 5, maxBins = 16, handleInvalid = "keep", seed = 46)
predictions <- predict(model, testdf)
expect_equal(class(collect(predictions)$clicked[1]), "character")
})
......
0% Loading or reload the page.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment