PySpark ML: keep instance uids as unicode across save/load.

* Params._resetUid converts the new uid with unicode() before storing it.
* Identifiable._randomUID returns unicode and keeps the last 12 hex chars of
  the UUID (hex[-12:]), matching its docstring; hex[12:] kept 20 of the 32.
* JavaModel.__init__ sets the uid through _resetUid instead of assigning it.
* PersistenceTest checks that a reloaded LinearRegression has an equal uid of
  the same type as the saved instance.

Note: the util.py index line still carries the post-image hash recorded before
the hex[-12:] adjustment.

diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 9f0b063aace5af0b927312fce84b74bca5e02539..40d830062581b41426983268ec580e45b76fdb94 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -485,10 +485,11 @@ class Params(Identifiable):
         Changes the uid of this instance. This updates both
         the stored uid and the parent uid of params and param maps.
         This is used by persistence (loading).
-        :param newUid: new uid to use
+        :param newUid: new uid to use, which is converted to unicode
         :return: same instance, but with the uid and Param.parent values
                  updated, including within param maps
         """
+        newUid = unicode(newUid)
         self.uid = newUid
         newDefaultParamMap = dict()
         newParamMap = dict()
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index d595eff5b47f800797b2a1d33f9cbbf942bac340..a7a9868baccb301c588101e1d710e18a8985aeae 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -621,6 +621,8 @@ class PersistenceTest(PySparkTestCase):
         lr_path = path + "/lr"
         lr.save(lr_path)
         lr2 = LinearRegression.load(lr_path)
+        self.assertEqual(lr.uid, lr2.uid)
+        self.assertEqual(type(lr.uid), type(lr2.uid))
         self.assertEqual(lr2.uid, lr2.maxIter.parent,
                          "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
                          % (lr2.uid, lr2.maxIter.parent))
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9dfcef0e40d67bdc9d2ffbdce863ffa30fb43c16..841bfb47e1b9d73639fc5bc743a8adf95f12dfd7 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -21,6 +21,7 @@ from functools import wraps
 
 if sys.version > '3':
     basestring = str
+    unicode = str
 
 from pyspark import SparkContext, since
 from pyspark.mllib.common import inherit_doc
@@ -67,10 +68,10 @@ class Identifiable(object):
     @classmethod
     def _randomUID(cls):
         """
-        Generate a unique id for the object. The default implementation
+        Generate a unique unicode id for the object. The default implementation
         concatenates the class name, "_", and 12 random hex chars.
         """
-        return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+        return unicode(cls.__name__ + "_" + uuid.uuid4().hex[-12:])
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 055a2816f8d753e137feeeb9f93b96f81b0469c0..fef626c7fa2cba11831761bc6fdb0215ff31c8cc 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -254,7 +254,7 @@ class JavaModel(JavaTransformer, Model):
         """
         super(JavaModel, self).__init__(java_model)
         if java_model is not None:
-            self.uid = java_model.uid()
+            self._resetUid(java_model.uid())
 
     def copy(self, extra=None):
         """