From 36da5e323487aa851a45475109185b9b0653db75 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" <joseph@databricks.com> Date: Sat, 16 Apr 2016 11:23:28 -0700 Subject: [PATCH] [SPARK-14605][ML][PYTHON] Changed Python to use unicode UIDs for spark.ml Identifiable ## What changes were proposed in this pull request? Python spark.ml Identifiable classes use UIDs of type str, but they should use unicode (in Python 2.x) to match Java. This could be a problem if someone created a class in Java with odd unicode characters, saved it, and loaded it in Python. This PR: Use unicode everywhere in Python. ## How was this patch tested? Updated persistence unit test to check uid type Author: Joseph K. Bradley <joseph@databricks.com> Closes #12368 from jkbradley/python-uid-unicode. --- python/pyspark/ml/param/__init__.py | 3 ++- python/pyspark/ml/tests.py | 2 ++ python/pyspark/ml/util.py | 5 +++-- python/pyspark/ml/wrapper.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 9f0b063aac..40d8300625 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -485,10 +485,11 @@ class Params(Identifiable): Changes the uid of this instance. This updates both the stored uid and the parent uid of params and param maps. This is used by persistence (loading). - :param newUid: new uid to use + :param newUid: new uid to use, which is converted to unicode :return: same instance, but with the uid and Param.parent values updated, including within param maps """ + newUid = unicode(newUid) self.uid = newUid newDefaultParamMap = dict() newParamMap = dict() diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index d595eff5b4..a7a9868bac 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -621,6 +621,8 @@ class PersistenceTest(PySparkTestCase): lr_path = path + "/lr" lr.save(lr_path) lr2 = LinearRegression.load(lr_path) + self.assertEqual(lr.uid, lr2.uid) + self.assertEqual(type(lr.uid), type(lr2.uid)) self.assertEqual(lr2.uid, lr2.maxIter.parent, "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)" % (lr2.uid, lr2.maxIter.parent)) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 9dfcef0e40..841bfb47e1 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -21,6 +21,7 @@ from functools import wraps if sys.version > '3': basestring = str + unicode = str from pyspark import SparkContext, since from pyspark.mllib.common import inherit_doc @@ -67,10 +68,10 @@ class Identifiable(object): @classmethod def _randomUID(cls): """ - Generate a unique id for the object. The default implementation + Generate a unique unicode id for the object. The default implementation concatenates the class name, "_", and 12 random hex chars. """ - return cls.__name__ + "_" + uuid.uuid4().hex[12:] + return unicode(cls.__name__ + "_" + uuid.uuid4().hex[12:]) @inherit_doc diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 055a2816f8..fef626c7fa 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -254,7 +254,7 @@ class JavaModel(JavaTransformer, Model): """ super(JavaModel, self).__init__(java_model) if java_model is not None: - self.uid = java_model.uid() + self._resetUid(java_model.uid()) def copy(self, extra=None): """ -- GitLab