diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index cbd508ae79acd4cadca3bab6832a3831c14e3e8c..7cbcccf2720a3510d858a93c84de2b9be1c4837b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -135,11 +135,6 @@ final class OneVsRestModel private[ml] (
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
-  /** A Python-friendly auxiliary constructor. */
-  private[ml] def this(uid: String, models: JList[_ <: ClassificationModel[_, _]]) = {
-    this(uid, Metadata.empty, models.asScala.toArray)
-  }
-
   /** @group setParam */
   @Since("2.1.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7e6e143523387659fbd600726edc154417520802..9d359427f27a60406c7fd79f688504fb0ea1bd21 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,7 +54,10 @@ object MimaExcludes {
     // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$11"),
+
+    // [SPARK-17161] Removing Python-friendly constructors not needed
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this")
   )
 
   // Exclude rules for 2.1.x
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f10556ca92290d93be2b0b9f72a25683774e6566..d41fc81fd75d4cc3b472d5e30162de6f5a6461a6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1517,6 +1517,11 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
     >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
     >>> model.transform(test2).head().prediction
     2.0
+    >>> model_path = temp_path + "/ovr_model"
+    >>> model.save(model_path)
+    >>> model2 = OneVsRestModel.load(model_path)
+    >>> model2.transform(test0).head().prediction
+    1.0
 
     .. versionadded:: 2.0.0
     """
@@ -1759,9 +1764,13 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
 
         :return: Java object equivalent to this instance.
         """
+        sc = SparkContext._active_spark_context
         java_models = [model._to_java() for model in self.models]
+        java_models_array = JavaWrapper._new_java_array(
+            java_models, sc._gateway.jvm.org.apache.spark.ml.classification.ClassificationModel)
+        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
-                                             self.uid, java_models)
+                                             self.uid, metadata.empty(), java_models_array)
         _java_obj.set("classifier", self.getClassifier()._to_java())
         _java_obj.set("featuresCol", self.getFeaturesCol())
         _java_obj.set("labelCol", self.getLabelCol())
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 68f5bc30ac57f854485c682dad3c30926b1bb313..53204cde29b74bd394d94cd19c186b3074fe0316 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -60,8 +60,8 @@ from pyspark.ml.recommendation import ALS
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
     GeneralizedLinearRegression
 from pyspark.ml.tuning import *
-from pyspark.ml.wrapper import JavaParams
-from pyspark.ml.common import _java2py
+from pyspark.ml.wrapper import JavaParams, JavaWrapper
+from pyspark.ml.common import _java2py, _py2java
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
@@ -1620,6 +1620,42 @@ class MatrixUDTTests(MLlibTestCase):
             raise ValueError("Expected a matrix but got type %r" % type(m))
 
 
+class WrapperTests(MLlibTestCase):
+
+    def test_new_java_array(self):
+        # test array of strings
+        str_list = ["a", "b", "c"]
+        java_class = self.sc._gateway.jvm.java.lang.String
+        java_array = JavaWrapper._new_java_array(str_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), str_list)
+        # test array of integers
+        int_list = [1, 2, 3]
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array(int_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), int_list)
+        # test array of floats
+        float_list = [0.1, 0.2, 0.3]
+        java_class = self.sc._gateway.jvm.java.lang.Double
+        java_array = JavaWrapper._new_java_array(float_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), float_list)
+        # test array of bools
+        bool_list = [False, True, True]
+        java_class = self.sc._gateway.jvm.java.lang.Boolean
+        java_array = JavaWrapper._new_java_array(bool_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), bool_list)
+        # test array of Java DenseVectors
+        v1 = DenseVector([0.0, 1.0])
+        v2 = DenseVector([1.0, 0.0])
+        vec_java_list = [_py2java(self.sc, v1), _py2java(self.sc, v2)]
+        java_class = self.sc._gateway.jvm.org.apache.spark.ml.linalg.DenseVector
+        java_array = JavaWrapper._new_java_array(vec_java_list, java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [v1, v2])
+        # test empty array
+        java_class = self.sc._gateway.jvm.java.lang.Integer
+        java_array = JavaWrapper._new_java_array([], java_class)
+        self.assertEqual(_java2py(self.sc, java_array), [])
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 13b75e99192215cd1f29db815adfd8bc92029fcc..80a0b31cd88d9459c13d325105aec56d1c8a0b16 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -16,6 +16,9 @@
 #
 
 from abc import ABCMeta, abstractmethod
+import sys
+if sys.version >= '3':
+    xrange = range
 
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
@@ -59,6 +62,32 @@ class JavaWrapper(object):
         java_args = [_py2java(sc, arg) for arg in args]
         return java_obj(*java_args)
 
+    @staticmethod
+    def _new_java_array(pylist, java_class):
+        """
+        Create a Java array of given java_class type. Useful for
+        calling a method with a Scala Array from Python with Py4J.
+
+        :param pylist:
+          Python list to convert to a Java Array.
+        :param java_class:
+          Java class to specify the type of Array. Should be in the
+          form of sc._gateway.jvm.* (sc is a valid Spark Context).
+        :return:
+          Java Array of converted pylist.
+
+        Example primitive Java classes:
+          - basestring -> sc._gateway.jvm.java.lang.String
+          - int -> sc._gateway.jvm.java.lang.Integer
+          - float -> sc._gateway.jvm.java.lang.Double
+          - bool -> sc._gateway.jvm.java.lang.Boolean
+        """
+        sc = SparkContext._active_spark_context
+        java_array = sc._gateway.new_array(java_class, len(pylist))
+        for i in xrange(len(pylist)):
+            java_array[i] = pylist[i]
+        return java_array
+
 
 @inherit_doc
 class JavaParams(JavaWrapper, Params):
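
For reference, a minimal usage sketch of the new helper (not part of the patch): it assumes an active SparkContext bound to the name sc, as in a PySpark shell, and mirrors the WrapperTests cases above rather than exercising any new API.

    from pyspark.ml.wrapper import JavaWrapper

    # Assumes an active SparkContext available as `sc` (e.g. a PySpark shell).
    # Build a java.lang.Double[] from a Python list of floats.
    java_doubles = JavaWrapper._new_java_array(
        [0.1, 0.2, 0.3], sc._gateway.jvm.java.lang.Double)

    # Py4J wraps the result as a JavaArray; indexing converts the boxed
    # java.lang.Double elements back to Python floats.
    assert list(java_doubles) == [0.1, 0.2, 0.3]

OneVsRestModel._to_java uses the same mechanism to pass a ClassificationModel[] to the primary Scala constructor, which takes a Scala Array where the deleted Python-friendly auxiliary constructor took a java.util.List.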