Skip to content
Snippets Groups Projects
Commit ad09e4ca authored by Yanbo Liang
Browse files

[MINOR][SPARKR][ML] Joint coefficients with intercept for SparkR linear SVM summary.

## What changes were proposed in this pull request?
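Join the coefficients with the intercept in the SparkR linear SVM summary: when the model is fitted with an intercept, it is now reported as the first row (`(Intercept)`) of the `coefficients` matrix rather than as a separate `intercept` entry in the summary list.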

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #18035 from yanboliang/svm-r.
parent 442287ae
No related branches found
No related tags found
No related merge requests found
...@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj" ...@@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"
#' @note NaiveBayesModel since 2.0.0 #' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj")) setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' linear SVM Model #' Linear SVM Model
#' #'
#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package #' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package.
#' Currently only supports binary classification model with linear kernel.
#' Users can print, make predictions on the produced model and save the model to the input path. #' Users can print, make predictions on the produced model and save the model to the input path.
#' #'
#' @param data SparkDataFrame for training. #' @param data SparkDataFrame for training.
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' @param formula A symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', '.', ':', '+', and '-'. #' operators are supported, including '~', '.', ':', '+', and '-'.
#' @param regParam The regularization parameter. #' @param regParam The regularization parameter. Only supports L2 regularization currently.
#' @param maxIter Maximum iteration number. #' @param maxIter Maximum iteration number.
#' @param tol Convergence tolerance of iterations. #' @param tol Convergence tolerance of iterations.
#' @param standardization Whether to standardize the training features before fitting the model. The coefficients #' @param standardization Whether to standardize the training features before fitting the model. The coefficients
...@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu ...@@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu
new("LinearSVCModel", jobj = jobj) new("LinearSVCModel", jobj = jobj)
}) })
# Predicted values based on an LinearSVCModel model # Predicted values based on a LinearSVCModel model
#' @param newData a SparkDataFrame for testing. #' @param newData a SparkDataFrame for testing.
#' @return \code{predict} returns the predicted values based on an LinearSVCModel. #' @return \code{predict} returns the predicted values based on a LinearSVCModel.
#' @rdname spark.svmLinear #' @rdname spark.svmLinear
#' @aliases predict,LinearSVCModel,SparkDataFrame-method #' @aliases predict,LinearSVCModel,SparkDataFrame-method
#' @export #' @export
...@@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"), ...@@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"),
predict_internal(object, newData) predict_internal(object, newData)
}) })
# Get the summary of an LinearSVCModel # Get the summary of a LinearSVCModel
#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}. #' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
#' @return \code{summary} returns summary information of the fitted model, which is a list. #' @return \code{summary} returns summary information of the fitted model, which is a list.
#' The list includes \code{coefficients} (coefficients of the fitted model), #' The list includes \code{coefficients} (coefficients of the fitted model),
#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes), #' \code{numClasses} (number of classes), \code{numFeatures} (number of features).
#' \code{numFeatures} (number of features).
#' @rdname spark.svmLinear #' @rdname spark.svmLinear
#' @aliases summary,LinearSVCModel-method #' @aliases summary,LinearSVCModel-method
#' @export #' @export
...@@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"), ...@@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"),
setMethod("summary", signature(object = "LinearSVCModel"), setMethod("summary", signature(object = "LinearSVCModel"),
function(object) { function(object) {
jobj <- object@jobj jobj <- object@jobj
features <- callJMethod(jobj, "features") features <- callJMethod(jobj, "rFeatures")
labels <- callJMethod(jobj, "labels") coefficients <- callJMethod(jobj, "rCoefficients")
coefficients <- callJMethod(jobj, "coefficients") coefficients <- as.matrix(unlist(coefficients))
nCol <- length(coefficients) / length(features) colnames(coefficients) <- c("Estimate")
coefficients <- matrix(unlist(coefficients), ncol = nCol) rownames(coefficients) <- unlist(features)
intercept <- callJMethod(jobj, "intercept")
numClasses <- callJMethod(jobj, "numClasses") numClasses <- callJMethod(jobj, "numClasses")
numFeatures <- callJMethod(jobj, "numFeatures") numFeatures <- callJMethod(jobj, "numFeatures")
if (nCol == 1) { list(coefficients = coefficients, numClasses = numClasses, numFeatures = numFeatures)
colnames(coefficients) <- c("Estimate")
} else {
colnames(coefficients) <- unlist(labels)
}
rownames(coefficients) <- unlist(features)
list(coefficients = coefficients, intercept = intercept,
numClasses = numClasses, numFeatures = numFeatures)
}) })
# Save fitted LinearSVCModel to the input path # Save fitted LinearSVCModel to the input path
......
...@@ -38,9 +38,8 @@ test_that("spark.svmLinear", { ...@@ -38,9 +38,8 @@ test_that("spark.svmLinear", {
expect_true(class(summary$coefficients[, 1]) == "numeric") expect_true(class(summary$coefficients[, 1]) == "numeric")
coefs <- summary$coefficients[, "Estimate"] coefs <- summary$coefficients[, "Estimate"]
expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085)
expect_true(all(abs(coefs - expected_coefs) < 0.1)) expect_true(all(abs(coefs - expected_coefs) < 0.1))
expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2)
# Test prediction with string label # Test prediction with string label
prediction <- predict(model, training) prediction <- predict(model, training)
......
...@@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private ( ...@@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private (
private val svcModel: LinearSVCModel = private val svcModel: LinearSVCModel =
pipeline.stages(1).asInstanceOf[LinearSVCModel] pipeline.stages(1).asInstanceOf[LinearSVCModel]
lazy val coefficients: Array[Double] = svcModel.coefficients.toArray lazy val rFeatures: Array[String] = if (svcModel.getFitIntercept) {
Array("(Intercept)") ++ features
} else {
features
}
lazy val intercept: Double = svcModel.intercept lazy val rCoefficients: Array[Double] = if (svcModel.getFitIntercept) {
Array(svcModel.intercept) ++ svcModel.coefficients.toArray
} else {
svcModel.coefficients.toArray
}
lazy val numClasses: Int = svcModel.numClasses lazy val numClasses: Int = svcModel.numClasses
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment