diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 67d037ed6e0249dac219e73b9ac365a37da55279..bd965acf56944d3708b1ff3dc6e325901ff4340e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -99,6 +99,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg val aft = new AFTSurvivalRegression() .setCensorCol(censorCol) .setFitIntercept(rFormula.hasIntercept) + .setFeaturesCol(rFormula.getFeaturesCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, aft)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala index b654233a893608b48cce5c0d7da6fb77c6884e72..b708702959829c9f749e13653d211a60d94722fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala @@ -85,6 +85,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp .setK(k) .setMaxIter(maxIter) .setTol(tol) + .setFeaturesCol(rFormula.getFeaturesCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, gm)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 35313258f940e707977e37a6694af6f0604759a8..b1bb577e1ffe43c8040d0b8207032c0f69f0eed5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -89,6 +89,7 @@ private[r] object GeneralizedLinearRegressionWrapper .setMaxIter(maxIter) .setWeightCol(weightCol) .setRegParam(regParam) + .setFeaturesCol(rFormula.getFeaturesCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 2ed7d7b770cc97dcfd27e5c3a9356a77d78d79f7..48632316f39508afc43921e39bcb30339e8ebac8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -75,6 +75,7 @@ private[r] object IsotonicRegressionWrapper .setIsotonic(isotonic) .setFeatureIndex(featureIndex) .setWeightCol(weightCol) + .setFeaturesCol(rFormula.getFeaturesCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, isotonicRegression)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index 8616a8c01e5ac243068da41eb579748a73d056c4..ea9458525aa31472fab4e3f79d9c3de2126cb238 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -86,6 +86,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] { .setK(k) .setMaxIter(maxIter) .setInitMode(initMode) + .setFeaturesCol(rFormula.getFeaturesCol) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, kMeans)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index f2cb24b96404b82cfbe03dd3043397f15874eed1..d1a39fea76ef863e2369d6e50e39f0bb77a8f57b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -73,6 +73,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val naiveBayes = new NaiveBayes() .setSmoothing(smoothing) .setModelType("bernoulli") + .setFeaturesCol(rFormula.getFeaturesCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 6a435992e3b359724fe8fdbfd0dc43929423decd..379007c4d948dedd738044b2b228092973f5bf3c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -19,14 +19,15 @@ package org.apache.spark.ml.r import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.Dataset object RWrapperUtils extends Logging { /** * DataFrame column check. - * When loading data, default columns "features" and "label" will be added. And these two names - * would conflict with RFormula default feature and label column names. + * When loading libsvm data, default columns "features" and "label" will be added. + * And "features" would conflict with RFormula default feature column names. * Here is to change the column name to avoid "column already exists" error. * * @param rFormula RFormula instance @@ -34,38 +35,11 @@ object RWrapperUtils extends Logging { * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { - if (data.schema.fieldNames.contains(rFormula.getLabelCol)) { - val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames) - logWarning( - s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead") - rFormula.setLabelCol(newLabelName) - } - if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { - val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames) + val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}" logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } } - - /** - * Convert conflicting name to be an unique name. - * Appending a sequence number, like originalName_output1 - * and incrementing until it is not already there - * - * @param originalName Original name - * @param fieldNames Array of field names in existing schema - * @return String - */ - def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = { - var counter = 1 - var newName = originalName + "_output" - - while (fieldNames.contains(newName)) { - newName = originalName + "_output" + counter - counter += 1 - } - newName - } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala index ddc24cb3a64811dc46135cb7a9e8ab7debc8e810..27b03918d951ef1db2f86f1c020b4b44e9a2c2d4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala @@ -35,22 +35,14 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { // after checking, model build is ok RWrapperUtils.checkDataColumns(rFormula, data) - assert(rFormula.getLabelCol == "label_output") - assert(rFormula.getFeaturesCol == "features_output") + assert(rFormula.getLabelCol == "label") + assert(rFormula.getFeaturesCol.startsWith("features_")) val model = rFormula.fit(data) assert(model.isInstanceOf[RFormulaModel]) - assert(model.getLabelCol == "label_output") - assert(model.getFeaturesCol == "features_output") - } - - test("generate unique name by appending a sequence number") { - val originalName = "label" - val fieldNames = Array("label_output", "label_output1", "label_output2") - val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames) - - assert(newName === "label_output3") + assert(model.getLabelCol == "label") + assert(model.getFeaturesCol.startsWith("features_")) } }