Commit 07be232e authored by Yanbo Liang

[SPARK-18412][SPARKR][ML] Fix exception for some SparkR ML algorithms training on libsvm data

## What changes were proposed in this pull request?
* Fix the following exception, which was thrown when ```spark.randomForest``` (classification), ```spark.gbt``` (classification), ```spark.naiveBayes```, or ```spark.glm``` (binomial family) was fitted on libsvm data.
```
java.lang.IllegalArgumentException: requirement failed: If label column already exists, forceIndexLabel can not be set with true.
```
See [SPARK-18412](https://issues.apache.org/jira/browse/SPARK-18412) for more detail about how to reproduce this bug; a minimal Scala sketch of the failing path follows this list.
* Refactor ```getFeaturesAndLabels``` out into ```RWrapperUtils```, since many ML algorithm wrappers use this function.
* Drop some unwanted columns when making predictions.
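
For context, here is a minimal sketch of the failing path, written in Scala to show the `RFormula` internals that the SparkR entry points drive. The local session setup and the bundled sample-data path are illustrative assumptions, not part of this patch:

```
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession

// Illustrative setup; any input that already has a "label" column
// (as libsvm data does) triggers the pre-fix failure.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("spark-18412-repro")
  .getOrCreate()
val data = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt") // columns: "label", "features"

val rFormula = new RFormula()
  .setFormula("label ~ features")
  .setForceIndexLabel(true) // the classification wrappers always set this

// This fit reproduces the exception verbatim, since RFormula's precondition is unchanged:
//   java.lang.IllegalArgumentException: requirement failed: If label column already exists,
//   forceIndexLabel can not be set with true.
// The patch fixes the SparkR wrappers, not RFormula itself: they now call
// RWrapperUtils.checkDataColumns(rFormula, data) before fitting, which points the formula
// at a fresh, randomly suffixed label column so the subsequent fit succeeds.
rFormula.fit(data)
```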

## How was this patch tested?
Added unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #15851 from yanboliang/spark-18412.
parent b91a51bb

--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -881,7 +881,8 @@ test_that("spark.kstest", {
   expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:")
 })
 
-test_that("spark.randomForest Regression", {
+test_that("spark.randomForest", {
+  # regression
   data <- suppressWarnings(createDataFrame(longley))
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                               numTrees = 1)
@@ -923,9 +924,8 @@ test_that("spark.randomForest Regression", {
   expect_equal(stats$treeWeights, stats2$treeWeights)
 
   unlink(modelPath)
-})
 
-test_that("spark.randomForest Classification", {
+  # classification
   data <- suppressWarnings(createDataFrame(iris))
   model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification",
                               maxDepth = 5, maxBins = 16)
@@ -971,6 +971,12 @@ test_that("spark.randomForest Classification", {
   predictions <- collect(predict(model, data))$prediction
   expect_equal(length(grep("1.0", predictions)), 50)
   expect_equal(length(grep("2.0", predictions)), 50)
+
+  # spark.randomForest classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"),
+                  source = "libsvm")
+  model <- spark.randomForest(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 4)
 })
 
 test_that("spark.gbt", {
@@ -1039,6 +1045,12 @@ test_that("spark.gbt", {
   expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction))
   expect_equal(s$numFeatures, 5)
   expect_equal(s$numTrees, 20)
+
+  # spark.gbt classification can work on libsvm data
+  data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"),
+                  source = "libsvm")
+  model <- spark.gbt(data, label ~ features, "classification")
+  expect_equal(summary(model)$numFeatures, 692)
 })
 
 sparkR.session.stop()

--- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassifierWrapper.scala
@@ -23,10 +23,10 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -51,6 +51,7 @@ private[r] class GBTClassifierWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(gbtcModel.getFeaturesCol)
+      .drop(gbtcModel.getLabelCol)
   }
 
   override def write: MLWriter = new
@@ -81,19 +82,11 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
-    // get feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
-
-    // get label names from output schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
 
     // assemble and fit the pipeline
     val rfc = new GBTClassifier()
@@ -109,6 +102,7 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper]
       .setMaxMemoryInMB(maxMemoryInMB)
       .setCacheNodeIds(cacheNodeIds)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
...

--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -29,6 +29,7 @@ import org.apache.spark.ml.regression._
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
@@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private (
         .drop(PREDICTED_LABEL_PROB_COL)
         .drop(PREDICTED_LABEL_INDEX_COL)
         .drop(glm.getFeaturesCol)
+        .drop(glm.getLabelCol)
     } else {
       pipeline.transform(dataset)
         .drop(glm.getFeaturesCol)
@@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       regParam: Double): GeneralizedLinearRegressionWrapper = {
     val rFormula = new RFormula().setFormula(formula)
     if (family == "binomial") rFormula.setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
     val schema = rFormulaModel.transform(data).schema
@@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper
       .setWeightCol(weightCol)
       .setRegParam(regParam)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
     val pipeline = if (family == "binomial") {
       // Convert prediction from probability to label index.
       val probToPred = new ProbabilityToPrediction()
...

--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -23,9 +23,9 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(naiveBayesModel.getFeaturesCol)
+      .drop(naiveBayesModel.getLabelCol)
   }
 
   override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this)
@@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
     // get labels and feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
     // assemble and fit the pipeline
     val naiveBayes = new NaiveBayes()
       .setSmoothing(smoothing)
       .setModelType("bernoulli")
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     val idxToStr = new IndexToString()
       .setInputCol(PREDICTED_LABEL_INDEX_COL)
...

--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
@@ -18,11 +18,12 @@
 package org.apache.spark.ml.r
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
+import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.Dataset
 
-object RWrapperUtils extends Logging {
+private[r] object RWrapperUtils extends Logging {
 
   /**
    * DataFrame column check.
@@ -32,14 +33,41 @@ object RWrapperUtils extends Logging {
    *
    * @param rFormula RFormula instance
    * @param data Input dataset
-   * @return Unit
    */
   def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
     if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
       val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}"
-      logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
+      logInfo(s"data containing ${rFormula.getFeaturesCol} column, " +
         s"using new name $newFeaturesName instead")
       rFormula.setFeaturesCol(newFeaturesName)
     }
+
+    if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) {
+      val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}"
+      logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " +
+        s"using new name $newLabelName instead")
+      rFormula.setLabelCol(newLabelName)
+    }
+  }
+
+  /**
+   * Get the feature names and original labels from the schema
+   * of DataFrame transformed by RFormulaModel.
+   *
+   * @param rFormulaModel The RFormulaModel instance.
+   * @param data Input dataset.
+   * @return The feature names and original labels.
+   */
+  def getFeaturesAndLabels(
+      rFormulaModel: RFormulaModel,
+      data: Dataset[_]): (Array[String], Array[String]) = {
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
+      .asInstanceOf[NominalAttribute]
+    val labels = labelAttr.values.get
+    (features, labels)
   }
 }
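
For readers skimming the diff, here is a condensed sketch of the shared pattern the classifier wrappers follow after this refactor. It paraphrases the code above rather than reproducing any one wrapper, the object and method names are hypothetical, and it must live in `org.apache.spark.ml.r` because `RWrapperUtils` is now `private[r]`:

```
package org.apache.spark.ml.r

import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.sql.Dataset

object WrapperPatternSketch {
  // Hypothetical condensation of the fit(...) flow in the wrappers above.
  def prepare(formula: String, data: Dataset[_]): Unit = {
    val rFormula = new RFormula()
      .setFormula(formula)
      .setForceIndexLabel(true)
    // Rename clashing "features"/"label" columns before fitting (the bug fix).
    checkDataColumns(rFormula, data)
    val rFormulaModel = rFormula.fit(data)
    // Shared extraction of feature names and original label strings (the refactor).
    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
    // The wrappers then configure their estimator with rFormula.getFeaturesCol and
    // rFormula.getLabelCol, and later map PREDICTED_LABEL_INDEX_COL back to labels.
  }
}
```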

--- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassifierWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassifierWrapper.scala
@@ -23,10 +23,10 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.ml.{Pipeline, PipelineModel}
-import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
 import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
 import org.apache.spark.ml.feature.{IndexToString, RFormula}
 import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.r.RWrapperUtils._
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -51,6 +51,7 @@ private[r] class RandomForestClassifierWrapper private (
     pipeline.transform(dataset)
       .drop(PREDICTED_LABEL_INDEX_COL)
       .drop(rfcModel.getFeaturesCol)
+      .drop(rfcModel.getLabelCol)
   }
 
   override def write: MLWriter = new
@@ -82,19 +83,11 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper]
     val rFormula = new RFormula()
       .setFormula(formula)
       .setForceIndexLabel(true)
-    RWrapperUtils.checkDataColumns(rFormula, data)
+    checkDataColumns(rFormula, data)
     val rFormulaModel = rFormula.fit(data)
 
-    // get feature names from output schema
-    val schema = rFormulaModel.transform(data).schema
-    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
-      .attributes.get
-    val features = featureAttrs.map(_.name.get)
-
-    // get label names from output schema
-    val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
-      .asInstanceOf[NominalAttribute]
-    val labels = labelAttr.values.get
+    // get labels and feature names from output schema
+    val (features, labels) = getFeaturesAndLabels(rFormulaModel, data)
 
     // assemble and fit the pipeline
     val rfc = new RandomForestClassifier()
@@ -111,6 +104,7 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper]
      .setCacheNodeIds(cacheNodeIds)
       .setProbabilityCol(probabilityCol)
       .setFeaturesCol(rFormula.getFeaturesCol)
+      .setLabelCol(rFormula.getLabelCol)
       .setPredictionCol(PREDICTED_LABEL_INDEX_COL)
     if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong)
...