Commit 23405f32 authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-15153][ML][SPARKR] Fix SparkR spark.naiveBayes error when label is numeric type

## What changes were proposed in this pull request?
Fix the SparkR ```spark.naiveBayes``` error that occurs when the response variable of the dataset is of numeric type.
See [SPARK-15153](https://issues.apache.org/jira/browse/SPARK-15153) for details and for steps to reproduce the bug.
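
For context, a minimal sketch of what the `setForceIndexLabel(true)` call in this patch changes. By default `RFormula` indexes the label only when it is a string column; with the flag set, a numeric label is indexed as well, so the output label column carries nominal metadata listing the class values. The toy data, the object name `ForceIndexLabelSketch`, and the local `SparkSession` settings are assumptions made for illustration and are not part of the patch.

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession

// Hypothetical standalone sketch; not code from this commit.
object ForceIndexLabelSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("force-index-label-sketch")
      .getOrCreate()
    import spark.implicits._

    // Toy data (assumed): a numeric 0/1 response and one categorical feature.
    val df = Seq((0.0, "Adult"), (1.0, "Child"), (0.0, "Adult"), (1.0, "Adult"))
      .toDF("NumericSurvived", "Age")

    // Force label indexing even though the response column is numeric.
    val model = new RFormula()
      .setFormula("NumericSurvived ~ Age")
      .setForceIndexLabel(true)
      .fit(df)
    val out = model.transform(df)

    // Inspect the ML attribute metadata attached to the output label column.
    Attribute.fromStructField(out.schema("label")) match {
      case nominal: NominalAttribute =>
        println(s"label values: ${nominal.values.map(_.mkString(", "))}")
      case other =>
        println(s"label is not nominal: $other")
    }

    spark.stop()
  }
}
```

With the flag left at its default of `false`, the same numeric column would be passed through without being indexed, which is presumably why a wrapper that reads class labels from the output schema could not handle a numeric response before this change.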

## How was this patch tested?
Added a unit test that covers a numeric response variable.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15431 from yanboliang/spark-15153-2.
parent 07508bd0
@@ -481,6 +481,16 @@ test_that("spark.naiveBayes", {
    expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
    expect_equal(as.character(predict(m, t1[1, ])), "Yes")
  }
  # Test numeric response variable
  t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1)
  t2 <- t1[-4]
  df <- suppressWarnings(createDataFrame(t2))
  m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0)
  s <- summary(m)
  expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6)
  expect_equal(sum(s$apriori), 1)
  expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6)
})
test_that("spark.survreg", {
......
@@ -59,6 +59,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
  def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
    val rFormula = new RFormula()
      .setFormula(formula)
      .setForceIndexLabel(true)
    RWrapperUtils.checkDataColumns(rFormula, data)
    val rFormulaModel = rFormula.fit(data)
    // get labels and feature names from output schema
......