diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index f406f8c426d0c18c5d9f22cafc5edd74482b1ed6..38176b96ba2ed6f41e58e0decb260ddd3a1e7129 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
    *
    * Check transform validity and derive the output schema from the input schema.
    *
+   * We check validity for interactions between parameters during `transformSchema` and
+   * raise an exception if any parameter value is invalid. Parameter value checks which
+   * do not depend on other parameters are handled by `Param.validate()`.
+   *
    * Typical implementation should first conduct verification on schema change and parameter
    * validity, including complex parameter interaction checks.
    */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 52f93f5a6b345aa12f55bf746ea04a52102c0888..ca5223133317cd7c549e43186bcee1824e603a0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -203,6 +203,12 @@ class GBTClassificationModel private[ml](
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
+  /**
+   * Number of trees in ensemble
+   */
+  @Since("2.0.0")
+  val getNumTrees: Int = trees.length
+
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fe29926e0d994962971c913278b093725a0210f9..41b84f481633cf6fa1f3e91f63a0ec589c60d342 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.VersionUtils
 
@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
     }
   }
 
-  override def validateParams(): Unit = {
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
     checkThresholdConsistency()
+    super.validateAndTransformSchema(schema, fitting, featuresDataType)
   }
 }
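The `validateAndTransformSchema` override above is the pattern this change standardizes: cross-parameter checks run during schema validation, so they fire for both `fit()` and `transform()`. A minimal sketch of the same pattern in a user-defined transformer — the `ScalingTransformer` name and its `min`/`max` params are illustrative, not part of Spark:

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class ScalingTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("scaling"))

  // Two params that are individually unconstrained...
  final val min = new Param[Double](this, "min", "lower bound of the target range")
  final val max = new Param[Double](this, "max", "upper bound of the target range")
  setDefault(min -> 0.0, max -> 1.0)

  // ...but must be checked against each other. Doing the check here means it
  // runs up front when a Pipeline validates its schema, not deep inside fit().
  override def transformSchema(schema: StructType): StructType = {
    require($(min) < $(max), s"min (${$(min)}) must be strictly less than max (${$(max)})")
    schema
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    dataset.toDF() // real transformation elided; only the validation flow is shown
  }

  override def copy(extra: ParamMap): ScalingTransformer = defaultCopy(extra)
}
```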
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 907c73e2e4d0ad59964daed40342a768a373e41e..d151213f9edd81edc8a9016f9c492c22f10caa1c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
     @Since("1.6.0") override val numFeatures: Int,
     @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
-  with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+  with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
   with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
     }
   }
 
-  /**
-   * Number of trees in ensemble
-   *
-   * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
-   */
-  // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
-  @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 653fa41124f88dda439914a53bfaa815e2d9603d..7cd0f159c6be72d10ce178bce0cde4c4fced6f58 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  /**
-   * @group setParam
-   */
-  @Since("1.6.0")
-  @deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
-  def setLabelCol(value: String): this.type = set(labelCol, value)
-
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 96206e0b7ad88a74feb78484f37f3e7d6d2e41f6..5bd8ebe0987a9bd85e24ca2fff2b565a55e5f2bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
       .map(m => m.invoke(this).asInstanceOf[Param[_]])
   }
 
-  /**
-   * Validates parameter values stored internally.
-   * Raise an exception if any parameter value is invalid.
-   *
-   * This only needs to check for interactions between parameters.
-   * Parameter value checks which do not depend on other parameters are handled by
-   * `Param.validate()`. This method does not handle input/output column parameters;
-   * those are checked during schema validation.
-   * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
-   */
-  @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
-  def validateParams(): Unit = {
-    // Do nothing by default. Override to handle Param interactions.
-  }
-
   /**
    * Explains a param.
    * @param param input param, must belong to this instance.
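With the deprecated `numTrees` val gone, callers read the ensemble size through the Param-backed `getNumTrees` that the model now inherits from `RandomForestClassifierParams`. A self-contained migration sketch (the object name and toy dataset are illustrative):

```scala
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object NumTreesExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("numTrees").getOrCreate()
    import spark.implicits._

    // Tiny toy dataset; any (label, features) DataFrame works.
    val training = Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(0.1, 1.2)),
      (1.0, Vectors.dense(2.2, 0.9))
    ).toDF("label", "features")

    val model = new RandomForestClassifier().setNumTrees(5).fit(training)
    println(model.getNumTrees) // 5; before 2.1 this was read via the deprecated `model.numTrees`

    spark.stop()
  }
}
```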
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index ed2d05525d611c48c65cbd706f6798e9cc0a793e..6d8159aa3bdcf926dff89120ed127595cf00f5f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
   @Since("1.4.0")
   override def trees: Array[DecisionTreeRegressionModel] = _trees
 
+  /**
+   * Number of trees in ensemble
+   */
+  @Since("2.0.0")
+  val getNumTrees: Int = trees.length
+
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index eb4e38cc83c1933f9a490d850f2ae1d14d7410a4..19ddf36a718c4fd529ddeca286bcc097e887acae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] (
     private val privateModel: LinearRegressionModel,
     private val diagInvAtWA: Array[Double]) extends Serializable {
 
-  @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
-  val model: LinearRegressionModel = privateModel
-
   @transient private val metrics = new RegressionMetrics(
     predictions
       .select(col(predictionCol), col(labelCol).cast(DoubleType))
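Code that read the fitted model back off the summary via the removed `model` field should instead keep the model reference it already has. A hedged sketch — the `training` DataFrame is assumed to exist with the usual `label`/`features` columns:

```scala
import org.apache.spark.ml.regression.LinearRegression

// `training` is assumed: a DataFrame of ("label", "features") rows.
val lrModel = new LinearRegression().setMaxIter(10).fit(training)

// The training summary is still attached to the model itself...
val summary = lrModel.summary
println(s"r2 = ${summary.r2}, RMSE = ${summary.rootMeanSquaredError}")

// ...so where pre-2.1 code wrote `summary.model`, use `lrModel` directly.
println(lrModel.coefficients)
```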
-  @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
-  val numTrees: Int = trees.length
-
   @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestRegressionModel = {
     copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index d3cbc363799a5f0b8d88b294b2b194a1ce85272c..0d6e9034e5ce42c42d19834c0d1006804dc9c069 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
   /** Trees in this ensemble. Warning: These have null parent Estimators. */
   def trees: Array[M]
 
-  /**
-   * Number of trees in ensemble
-   */
-  val getNumTrees: Int = trees.length
-
   /** Weights for each tree, zippable with [[trees]] */
   def treeWeights: Array[Double]
 
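The `getNumTrees` val is deleted from `TreeEnsembleModel` because it moves into the params hierarchy (next file), where it is backed by the `numTrees` Param. A toy, non-Spark sketch of why removing the model's concrete `val numTrees` lets estimator and model finally share one trait — all names below are made up for illustration:

```scala
// Toy sketch, not Spark source. One shared params trait...
trait ForestParams {
  protected def numTreesValue: Int = 20 // stands in for a Param default
  final def getNumTrees: Int = numTreesValue
}

// ...now works for both sides. Previously the model's deprecated `val numTrees`
// clashed with the accessor inherited from the estimator's trait, which is
// what forced the HasNumTrees / *ModelParams split being deleted here.
class ForestEstimator extends ForestParams
class ForestModel(trees: Seq[String]) extends ForestParams {
  override protected def numTreesValue: Int = trees.length
}

object TraitSharingDemo extends App {
  assert(new ForestModel(Seq("t1", "t2", "t3")).getNumTrees == 3)
  assert(new ForestEstimator().getNumTrees == 20)
}
```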
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 40510ad804ef0dbacde708b337e8b8160fc9b21f..83ab4b5da87be732eb3772aab9f595ce4bfa9d9f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -319,8 +319,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
   }
 }
 
-/** Used for [[RandomForestParams]] */
-private[ml] trait HasFeatureSubsetStrategy extends Params {
+/**
+ * Parameters for Random Forest algorithms.
+ */
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler). SPARK-7130
+   * (default = 20)
+   *
+   * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams)
+   * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms
+   * are a bit different.
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
+
+  setDefault(numTrees -> 20)
+
+  /** @group setParam */
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+  /** @group getParam */
+  final def getNumTrees: Int = $(numTrees)
 
   /**
    * The number of features to consider for splits at each tree node.
@@ -366,38 +390,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
   final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
 }
 
-/**
- * Used for [[RandomForestParams]].
- * This is separated out from [[RandomForestParams]] because of an issue with the
- * `numTrees` method conflicting with this Param in the Estimator.
- */
-private[ml] trait HasNumTrees extends Params {
-
-  /**
-   * Number of trees to train (>= 1).
-   * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done.
-   * TODO: Change to always do bootstrapping (simpler). SPARK-7130
-   * (default = 20)
-   * @group param
-   */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidators.gtEq(1))
-
-  setDefault(numTrees -> 20)
-
-  /** @group setParam */
-  def setNumTrees(value: Int): this.type = set(numTrees, value)
-
-  /** @group getParam */
-  final def getNumTrees: Int = $(numTrees)
-}
-
-/**
- * Parameters for Random Forest algorithms.
- */
-private[ml] trait RandomForestParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with HasNumTrees
-
 private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =
@@ -407,21 +399,15 @@ private[spark] object RandomForestParams {
 
 private[ml] trait RandomForestClassifierParams
   extends RandomForestParams with TreeClassifierParams
 
-private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeClassifierParams
-
 private[ml] trait RandomForestRegressorParams
   extends RandomForestParams with TreeRegressorParams
 
-private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams
-  with HasFeatureSubsetStrategy with TreeRegressorParams
-
 /**
  * Parameters for Gradient-Boosted Tree algorithms.
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
 
   /* TODO: Add this doc when we add this param.  SPARK-7132
    * Threshold for stopping early when runWithValidation is used.
@@ -434,24 +420,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
   // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
   // validationTol -> 1e-5
 
-  setDefault(maxIter -> 20, stepSize -> 0.1)
-
   /** @group setParam */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /**
-   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
-   * estimator.
+   * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
+   * the contribution of each estimator.
    * (default = 0.1)
-   * @group setParam
+   * @group param
    */
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
+    "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+  /** @group getParam */
+  final def getStepSize: Double = $(stepSize)
+
+  /** @group setParam */
   def setStepSize(value: Double): this.type = set(stepSize, value)
 
-  override def validateParams(): Unit = {
-    require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
-      getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
-      s"but it given invalid value $getStepSize.")
-  }
+  setDefault(maxIter -> 20, stepSize -> 0.1)
 
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
   private[ml] def getOldBoostingStrategy(
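Because the range check now lives in the `stepSize` Param's `ParamValidators.inRange` validator rather than in the removed `validateParams()`, a bad value is rejected the moment it is set. A small sketch, runnable without a SparkSession since constructing the estimator needs no cluster:

```scala
import org.apache.spark.ml.classification.GBTClassifier

object StepSizeValidationDemo extends App {
  val gbt = new GBTClassifier()
  try {
    gbt.setStepSize(10.0) // outside (0, 1]: now fails eagerly at set time
  } catch {
    case e: IllegalArgumentException => println(s"rejected: ${e.getMessage}")
  }
  gbt.setStepSize(0.1) // in range: accepted
  println(s"stepSize = ${gbt.getStepSize}")
}
```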
   * Sets the Spark SQLContext to use for saving/loading.
    */
   @Since("1.6.0")
-  @deprecated("Use session instead", "2.0.0")
+  @deprecated("Use session instead. This method will be removed in 2.2.0.", "2.0.0")
   def context(sqlContext: SQLContext): this.type = {
     optionSparkSession = Option(sqlContext.sparkSession)
     this
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3492709677d4ff20d4237dc31cadf1c6c5b20b1d..7c36745ab213b9466fde0ce314fabf0ea4284aae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
     ParamsSuite.checkParams(model)
   }
 
+  test("GBT parameter stepSize should be in interval (0, 1]") {
+    withClue("GBT parameter stepSize should be in interval (0, 1]") {
+      intercept[IllegalArgumentException] {
+        new GBTClassifier().setStepSize(10)
+      }
+    }
+  }
+
   test("Binary classification with continuous features: Log Loss") {
     val categoricalFeatures = Map.empty[Int, Int]
     testCombinations.foreach {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e360542eae2ab57e7ad05e4316c75cf24492c046..9c4c59a5e60fa54bc960dd197d9d8e89d0edb6e8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -192,6 +192,12 @@ class LogisticRegressionSuite
       }
     }
     // thresholds and threshold must be consistent: values
+    withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
+      intercept[IllegalArgumentException] {
+        lr2.fit(smallBinaryDataset,
+          lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0))
+      }
+    }
     withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") {
       intercept[IllegalArgumentException] {
         val lr2model = lr2.fit(smallBinaryDataset,
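The stronger deprecation notice above points callers at the `session` setter that already exists on `MLWriter`/`MLReader`. A migration sketch (the `/tmp/lr-demo` path is arbitrary):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.SparkSession

object SessionReadWriteDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("rw").getOrCreate()

  val lr = new LogisticRegression().setMaxIter(5)

  // Old (deprecated): lr.write.context(sqlContext).save(path)
  // New: hand the writer/reader a SparkSession instead.
  lr.write.session(spark).overwrite().save("/tmp/lr-demo")
  val restored = LogisticRegression.read.session(spark).load("/tmp/lr-demo")
  println(restored.getMaxIter) // 5

  spark.stop()
}
```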
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 12f7ed202b9dbfffc2d02e2caa7a27645e43a359..84014014f2f59bf4100456eaee8117c8d4fb83c1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -867,6 +867,36 @@ object MimaExcludes {
       // [SPARK-12221] Add CPU time to metrics
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
       ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
+    ) ++ Seq(
+      // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
+      ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
+      ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
+      ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
+      ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
     )
   }
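Each removal above needs a matching MiMa exclusion or the `mimaReportBinaryIssues` sbt check fails the build. The entries all follow one shape — a `ProblemFilter` naming the member whose binary break is intentional. A hypothetical standalone entry (`org.example.Widget.oldMethod` is made up):

```scala
import com.typesafe.tools.mima.core._

// Tells MiMa that removing org.example.Widget.oldMethod is an intentional,
// reviewed binary break rather than an accidental regression.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.example.Widget.oldMethod")
```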
+ """ + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def session(self, sparkSession): + """Sets the Spark Session to use for saving.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -105,10 +112,19 @@ class JavaMLWriter(MLWriter): return self def context(self, sqlContext): - """Sets the SQL context to use for saving.""" + """ + Sets the SQL context to use for saving. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") self._jwrite.context(sqlContext._ssql_ctx) return self + def session(self, sparkSession): + """Sets the Spark Session to use for saving.""" + self._jwrite.session(sparkSession._jsparkSession) + return self + @inherit_doc class MLWritable(object): @@ -155,7 +171,14 @@ class MLReader(object): raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) def context(self, sqlContext): - """Sets the SQL context to use for loading.""" + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def session(self, sparkSession): + """Sets the Spark Session to use for loading.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -180,10 +203,19 @@ class JavaMLReader(MLReader): return self._clazz._from_java(java_obj) def context(self, sqlContext): - """Sets the SQL context to use for loading.""" + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") self._jread.context(sqlContext._ssql_ctx) return self + def session(self, sparkSession): + """Sets the Spark Session to use for loading.""" + self._jread.session(sparkSession._jsparkSession) + return self + @classmethod def _java_loader_class(cls, clazz): """