Commit 0583ecda authored by Felix Cheung, committed by Felix Cheung

[SPARK-17173][SPARKR] R MLlib refactor, cleanup, reformat, fix deprecation in test

## What changes were proposed in this pull request?

Refactor, clean up, and reformat the SparkR MLlib wrapper code (mllib.R): consolidate the duplicated `write.ml`/`predict` method bodies into shared helpers, drop explicit `return()` calls, fix indentation, and fix a deprecated testthat usage in the tests.
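
Concretely, the duplicated save and predict bodies in every `write.ml` and `predict` method are consolidated into two shared helpers, and explicit `return()` calls are dropped in favor of R's implicit last-expression return. A minimal sketch of the pattern as it appears in the diff below; the helpers depend on SparkR's internal JVM bridge (`callJMethod`) and its `dataFrame` wrapper, so they only run inside the SparkR package:

```r
# Shared save helper: every write.ml method now delegates here.
write_internal <- function(object, path, overwrite = FALSE) {
  writer <- callJMethod(object@jobj, "write")
  if (overwrite) {
    writer <- callJMethod(writer, "overwrite")
  }
  invisible(callJMethod(writer, "save", path))
}

# Shared prediction helper: every predict method now delegates here.
predict_internal <- function(object, newData) {
  dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
}

# Example: predict(KMeansModel) after the refactor. Previously the body was
# return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))).
setMethod("predict", signature(object = "KMeansModel"),
          function(object, newData) {
            predict_internal(object, newData)
          })
```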

## How was this patch tested?

unit tests, manual tests
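
The deprecation fix replaces testthat's deprecated `not(throws_error())` negation with the `expect_error(expr, NA)` form, where `NA` means "expect no error". From the naive Bayes test:

```r
# Before (deprecated testthat idiom):
# expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))

# After: passing NA as the expected error asserts that none is raised.
expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
```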

Author: Felix Cheung <felixcheung_m@hotmail.com>

Closes #14735 from felixcheung/rmllibutil.
parent 342278c0
@@ -88,9 +88,9 @@ setClass("ALSModel", representation(jobj = "jobj"))
 #' @rdname write.ml
 #' @name write.ml
 #' @export
-#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.lda}, \link{spark.naiveBayes}
-#' @seealso \link{spark.survreg}, \link{spark.isoreg}
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
+#' @seealso \link{spark.lda}, \link{spark.naiveBayes}, \link{spark.survreg},
 #' @seealso \link{read.ml}
 NULL
@@ -101,11 +101,22 @@ NULL
 #' @rdname predict
 #' @name predict
 #' @export
-#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.als}, \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
-#' @seealso \link{spark.isoreg}
+#' @seealso \link{spark.glm}, \link{glm},
+#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans},
+#' @seealso \link{spark.naiveBayes}, \link{spark.survreg},
 NULL
 
+write_internal <- function(object, path, overwrite = FALSE) {
+  writer <- callJMethod(object@jobj, "write")
+  if (overwrite) {
+    writer <- callJMethod(writer, "overwrite")
+  }
+  invisible(callJMethod(writer, "save", path))
+}
+
+predict_internal <- function(object, newData) {
+  dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+}
+
 #' Generalized Linear Models
 #'
@@ -173,7 +184,7 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
             jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
                                 "fit", formula, data@sdf, family$family, family$link,
                                 tol, as.integer(maxIter), as.character(weightCol))
-            return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+            new("GeneralizedLinearRegressionModel", jobj = jobj)
           })
 
 #' Generalized Linear Models (R-compliant)
@@ -219,7 +230,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"),
 #' @export
 #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0
 setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
-          function(object, ...) {
+          function(object) {
             jobj <- object@jobj
             is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "rFeatures")
@@ -245,7 +256,7 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"),
                         deviance = deviance, df.null = df.null, df.residual = df.residual,
                         aic = aic, iter = iter, family = family, is.loaded = is.loaded)
             class(ans) <- "summary.GeneralizedLinearRegressionModel"
-            return(ans)
+            ans
           })
 
 # Prints the summary of GeneralizedLinearRegressionModel
@@ -275,8 +286,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
     " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"),
     1L, paste, collapse = " "), sep = "")
   cat("AIC: ", format(x$aic, digits = 4L), "\n\n",
-      "Number of Fisher Scoring iterations: ", x$iter, "\n", sep = "")
-  cat("\n")
+      "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "")
   invisible(x)
 }
@@ -291,7 +301,7 @@ print.summary.GeneralizedLinearRegressionModel <- function(x, ...) {
 #' @note predict(GeneralizedLinearRegressionModel) since 1.5.0
 setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 # Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(),
@@ -305,7 +315,7 @@ setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"),
 #' @note predict(NaiveBayesModel) since 2.0.0
 setMethod("predict", signature(object = "NaiveBayesModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 # Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes}
@@ -317,7 +327,7 @@ setMethod("predict", signature(object = "NaiveBayesModel"),
 #' @export
 #' @note summary(NaiveBayesModel) since 2.0.0
 setMethod("summary", signature(object = "NaiveBayesModel"),
-          function(object, ...) {
+          function(object) {
             jobj <- object@jobj
             features <- callJMethod(jobj, "features")
             labels <- callJMethod(jobj, "labels")
@@ -328,7 +338,7 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
             tables <- matrix(tables, nrow = length(labels))
             rownames(tables) <- unlist(labels)
             colnames(tables) <- unlist(features)
-            return(list(apriori = apriori, tables = tables))
+            list(apriori = apriori, tables = tables)
           })
 
 # Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
@@ -342,7 +352,7 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
 #' @note spark.posterior(LDAModel) since 2.1.0
 setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 # Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
@@ -377,12 +387,11 @@ setMethod("summary", signature(object = "LDAModel"),
             vocabSize <- callJMethod(jobj, "vocabSize")
             topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
             vocabulary <- callJMethod(jobj, "vocabulary")
-            return(list(docConcentration = unlist(docConcentration),
-                        topicConcentration = topicConcentration,
-                        logLikelihood = logLikelihood, logPerplexity = logPerplexity,
-                        isDistributed = isDistributed, vocabSize = vocabSize,
-                        topics = topics,
-                        vocabulary = unlist(vocabulary)))
+            list(docConcentration = unlist(docConcentration),
+                 topicConcentration = topicConcentration,
+                 logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+                 isDistributed = isDistributed, vocabSize = vocabSize,
+                 topics = topics, vocabulary = unlist(vocabulary))
           })
 
 # Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
@@ -395,8 +404,8 @@ setMethod("summary", signature(object = "LDAModel"),
 #' @note spark.perplexity(LDAModel) since 2.1.0
 setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
           function(object, data) {
-            return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
-                          callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+            ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+                   callJMethod(object@jobj, "computeLogPerplexity", data@sdf))
           })
 
 # Saves the Latent Dirichlet Allocation model to the input path.
@@ -412,11 +421,7 @@ setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
 #' @note write.ml(LDAModel, character) since 2.1.0
 setMethod("write.ml", signature(object = "LDAModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 #' Isotonic Regression Model
@@ -471,9 +476,9 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"),
             }
             jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
-                        data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
-                        as.character(weightCol))
-            return(new("IsotonicRegressionModel", jobj = jobj))
+                                data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
+                                as.character(weightCol))
+            new("IsotonicRegressionModel", jobj = jobj)
           })
 
 # Predicted values based on an isotonicRegression model
@@ -487,7 +492,7 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"),
 #' @note predict(IsotonicRegressionModel) since 2.1.0
 setMethod("predict", signature(object = "IsotonicRegressionModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 # Get the summary of an IsotonicRegressionModel model
@@ -499,11 +504,11 @@ setMethod("predict", signature(object = "IsotonicRegressionModel"),
 #' @export
 #' @note summary(IsotonicRegressionModel) since 2.1.0
 setMethod("summary", signature(object = "IsotonicRegressionModel"),
-          function(object, ...) {
+          function(object) {
             jobj <- object@jobj
             boundaries <- callJMethod(jobj, "boundaries")
             predictions <- callJMethod(jobj, "predictions")
-            return(list(boundaries = boundaries, predictions = predictions))
+            list(boundaries = boundaries, predictions = predictions)
           })
 
 #' K-Means Clustering Model
@@ -553,7 +558,7 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
             initMode <- match.arg(initMode)
             jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
                                 as.integer(k), as.integer(maxIter), initMode)
-            return(new("KMeansModel", jobj = jobj))
+            new("KMeansModel", jobj = jobj)
           })
 
 #' Get fitted result from a k-means model
@@ -576,14 +581,14 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
 #'}
 #' @note fitted since 2.0.0
 setMethod("fitted", signature(object = "KMeansModel"),
-          function(object, method = c("centers", "classes"), ...) {
+          function(object, method = c("centers", "classes")) {
             method <- match.arg(method)
             jobj <- object@jobj
             is.loaded <- callJMethod(jobj, "isLoaded")
             if (is.loaded) {
-              stop(paste("Saved-loaded k-means model does not support 'fitted' method"))
+              stop("Saved-loaded k-means model does not support 'fitted' method")
             } else {
-              return(dataFrame(callJMethod(jobj, "fitted", method)))
+              dataFrame(callJMethod(jobj, "fitted", method))
             }
           })
@@ -595,7 +600,7 @@ setMethod("fitted", signature(object = "KMeansModel"),
 #' @export
 #' @note summary(KMeansModel) since 2.0.0
 setMethod("summary", signature(object = "KMeansModel"),
-          function(object, ...) {
+          function(object) {
             jobj <- object@jobj
             is.loaded <- callJMethod(jobj, "isLoaded")
             features <- callJMethod(jobj, "features")
@@ -610,8 +615,8 @@ setMethod("summary", signature(object = "KMeansModel"),
             } else {
               dataFrame(callJMethod(jobj, "cluster"))
             }
-            return(list(coefficients = coefficients, size = size,
-                        cluster = cluster, is.loaded = is.loaded))
+            list(coefficients = coefficients, size = size,
+                 cluster = cluster, is.loaded = is.loaded)
           })
 
 # Predicted values based on a k-means model
@@ -623,7 +628,7 @@ setMethod("summary", signature(object = "KMeansModel"),
 #' @note predict(KMeansModel) since 2.0.0
 setMethod("predict", signature(object = "KMeansModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 #' Naive Bayes Models
@@ -665,11 +670,11 @@ setMethod("predict", signature(object = "KMeansModel"),
 #' }
 #' @note spark.naiveBayes since 2.0.0
 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
-          function(data, formula, smoothing = 1.0, ...) {
+          function(data, formula, smoothing = 1.0) {
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit",
                                 formula, data@sdf, smoothing)
-            return(new("NaiveBayesModel", jobj = jobj))
+            new("NaiveBayesModel", jobj = jobj)
           })
 
 # Saves the Bernoulli naive Bayes model to the input path.
@@ -684,11 +689,7 @@ setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"),
 #' @note write.ml(NaiveBayesModel, character) since 2.0.0
 setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 # Saves the AFT survival regression model to the input path.
@@ -702,11 +703,7 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
 #' @seealso \link{read.ml}
 setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 # Saves the generalized linear model to the input path.
@@ -720,11 +717,7 @@ setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
 #' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0
 setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 # Save fitted MLlib model to the input path
@@ -738,11 +731,7 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"),
 #' @note write.ml(KMeansModel, character) since 2.0.0
 setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 # Save fitted IsotonicRegressionModel to the input path
@@ -757,11 +746,7 @@ setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
 #' @note write.ml(IsotonicRegression, character) since 2.1.0
 setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
           })
 
 # Save fitted MLlib model to the input path
@@ -776,11 +761,7 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"),
 #' @note write.ml(GaussianMixtureModel, character) since 2.1.0
 setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
           function(object, path, overwrite = FALSE) {
-            writer <- callJMethod(object@jobj, "write")
-            if (overwrite) {
-              writer <- callJMethod(writer, "overwrite")
-            }
-            invisible(callJMethod(writer, "save", path))
+            write_internal(object, path, overwrite)
          })
 
 #' Load a fitted MLlib model from the input path.
@@ -801,21 +782,21 @@ read.ml <- function(path) {
   path <- suppressWarnings(normalizePath(path))
   jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path)
   if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) {
-    return(new("NaiveBayesModel", jobj = jobj))
+    new("NaiveBayesModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) {
-    return(new("AFTSurvivalRegressionModel", jobj = jobj))
+    new("AFTSurvivalRegressionModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) {
-    return(new("GeneralizedLinearRegressionModel", jobj = jobj))
+    new("GeneralizedLinearRegressionModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
-    return(new("KMeansModel", jobj = jobj))
+    new("KMeansModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
-    return(new("LDAModel", jobj = jobj))
+    new("LDAModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
-    return(new("IsotonicRegressionModel", jobj = jobj))
+    new("IsotonicRegressionModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
-    return(new("GaussianMixtureModel", jobj = jobj))
+    new("GaussianMixtureModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) {
-    return(new("ALSModel", jobj = jobj))
+    new("ALSModel", jobj = jobj)
   } else {
     stop(paste("Unsupported model: ", jobj))
   }
@@ -860,7 +841,7 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                 "fit", formula, data@sdf)
-            return(new("AFTSurvivalRegressionModel", jobj = jobj))
+            new("AFTSurvivalRegressionModel", jobj = jobj)
           })
 
 #' Latent Dirichlet Allocation
@@ -926,7 +907,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"),
                                 as.numeric(subsamplingRate), topicConcentration,
                                 as.array(docConcentration), as.array(customizedStopWords),
                                 maxVocabSize)
-            return(new("LDAModel", jobj = jobj))
+            new("LDAModel", jobj = jobj)
           })
 
 # Returns a summary of the AFT survival regression model produced by spark.survreg,
@@ -946,7 +927,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
             coefficients <- as.matrix(unlist(coefficients))
             colnames(coefficients) <- c("Value")
             rownames(coefficients) <- unlist(features)
-            return(list(coefficients = coefficients))
+            list(coefficients = coefficients)
           })
 
 # Makes predictions from an AFT survival regression model or a model produced by
@@ -960,7 +941,7 @@ setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
 #' @note predict(AFTSurvivalRegressionModel) since 2.0.0
 setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 #' Multivariate Gaussian Mixture Model (GMM)
@@ -1014,7 +995,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
             formula <- paste(deparse(formula), collapse = "")
             jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
                                 formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
-            return(new("GaussianMixtureModel", jobj = jobj))
+            new("GaussianMixtureModel", jobj = jobj)
           })
 
 # Get the summary of a multivariate gaussian mixture model
@@ -1027,7 +1008,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
 #' @export
 #' @note summary(GaussianMixtureModel) since 2.1.0
 setMethod("summary", signature(object = "GaussianMixtureModel"),
-          function(object, ...) {
+          function(object) {
             jobj <- object@jobj
             is.loaded <- callJMethod(jobj, "isLoaded")
             lambda <- unlist(callJMethod(jobj, "lambda"))
@@ -1052,8 +1033,8 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
             } else {
               dataFrame(callJMethod(jobj, "posterior"))
             }
-            return(list(lambda = lambda, mu = mu, sigma = sigma,
-                        posterior = posterior, is.loaded = is.loaded))
+            list(lambda = lambda, mu = mu, sigma = sigma,
+                 posterior = posterior, is.loaded = is.loaded)
           })
 
 # Predicted values based on a gaussian mixture model
@@ -1067,7 +1048,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
 #' @note predict(GaussianMixtureModel) since 2.1.0
 setMethod("predict", signature(object = "GaussianMixtureModel"),
           function(object, newData) {
-            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+            predict_internal(object, newData)
           })
 
 #' Alternating Least Squares (ALS) for Collaborative Filtering
@@ -1149,7 +1130,7 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
                                 reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative,
                                 as.integer(numUserBlocks), as.integer(numItemBlocks),
                                 as.integer(checkpointInterval), as.integer(seed))
-            return(new("ALSModel", jobj = jobj))
+            new("ALSModel", jobj = jobj)
           })
 
 # Returns a summary of the ALS model produced by spark.als.
@@ -1163,17 +1144,17 @@ setMethod("spark.als", signature(data = "SparkDataFrame"),
 #' @export
 #' @note summary(ALSModel) since 2.1.0
 setMethod("summary", signature(object = "ALSModel"),
-function(object, ...) {
-    jobj <- object@jobj
-    user <- callJMethod(jobj, "userCol")
-    item <- callJMethod(jobj, "itemCol")
-    rating <- callJMethod(jobj, "ratingCol")
-    userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
-    itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
-    rank <- callJMethod(jobj, "rank")
-    return(list(user = user, item = item, rating = rating, userFactors = userFactors,
-                itemFactors = itemFactors, rank = rank))
-})
+          function(object) {
+            jobj <- object@jobj
+            user <- callJMethod(jobj, "userCol")
+            item <- callJMethod(jobj, "itemCol")
+            rating <- callJMethod(jobj, "ratingCol")
+            userFactors <- dataFrame(callJMethod(jobj, "userFactors"))
+            itemFactors <- dataFrame(callJMethod(jobj, "itemFactors"))
+            rank <- callJMethod(jobj, "rank")
+            list(user = user, item = item, rating = rating, userFactors = userFactors,
+                 itemFactors = itemFactors, rank = rank)
+          })
 
 # Makes predictions from an ALS model or a model produced by spark.als.
@@ -1185,9 +1166,9 @@ function(object, ...) {
 #' @export
 #' @note predict(ALSModel) since 2.1.0
 setMethod("predict", signature(object = "ALSModel"),
-function(object, newData) {
-    return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
-})
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
 
 # Saves the ALS model to the input path.
@@ -1203,10 +1184,6 @@ function(object, newData) {
 #' @seealso \link{read.ml}
 #' @note write.ml(ALSModel, character) since 2.1.0
 setMethod("write.ml", signature(object = "ALSModel", path = "character"),
-function(object, path, overwrite = FALSE) {
-    writer <- callJMethod(object@jobj, "write")
-    if (overwrite) {
-        writer <- callJMethod(writer, "overwrite")
-    }
-    invisible(callJMethod(writer, "save", path))
-})
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
@@ -95,6 +95,10 @@ test_that("spark.glm summary", {
   expect_equal(stats$df.residual, rStats$df.residual)
   expect_equal(stats$aic, rStats$aic)
 
+  out <- capture.output(print(stats))
+  expect_match(out[2], "Deviance Residuals:")
+  expect_true(any(grepl("AIC: 59.22", out)))
+
   # binomial family
   df <- suppressWarnings(createDataFrame(iris))
   training <- df[df$Species %in% c("versicolor", "virginica"), ]
@@ -409,7 +413,7 @@ test_that("spark.naiveBayes", {
   # Test e1071::naiveBayes
   if (requireNamespace("e1071", quietly = TRUE)) {
-    expect_that(m <- e1071::naiveBayes(Survived ~ ., data = t1), not(throws_error()))
+    expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA)
     expect_equal(as.character(predict(m, t1[1, ])), "Yes")
   }
 })
@@ -487,7 +491,7 @@ test_that("spark.isotonicRegression", {
                        weightCol = "weight")
   # only allow one variable on the right hand side of the formula
   expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE))
-  result <- summary(model, df)
+  result <- summary(model)
   expect_equal(result$predictions, list(7, 5, 4, 4, 1))
 
   # Test model prediction
@@ -503,7 +507,7 @@ test_that("spark.isotonicRegression", {
   expect_error(write.ml(model, modelPath))
   write.ml(model, modelPath, overwrite = TRUE)
   model2 <- read.ml(modelPath)
-  expect_equal(result, summary(model2, df))
+  expect_equal(result, summary(model2))
   unlink(modelPath)
 })