diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 13342073b49882d63445141cb5712af3ad15b795..043c25cf9feb4ae72dbc0ad59bd486c2bb9e450a 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -330,7 +330,7 @@ class Params(Identifiable): Tests whether this instance contains a param with a given (string) name. """ - if isinstance(paramName, str): + if isinstance(paramName, basestring): p = getattr(self, paramName, None) return isinstance(p, Param) else: @@ -413,7 +413,7 @@ class Params(Identifiable): if isinstance(param, Param): self._shouldOwn(param) return param - elif isinstance(param, str): + elif isinstance(param, basestring): return self.getParam(param) else: raise ValueError("Cannot resolve %r as a param." % param) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6076b3c2f26a63d3a59d2075625dab31c71c9023..509698f6014ebb3987462fba8803e02307e5bac0 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase): testParams = TestParams() self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. 
+ e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"아")) def test_params(self): testParams = TestParams() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1cea130d918ad618fabfdb54a4f7ca3ea9af49e3..8f88545443c75c0452972ab76886de5d31206446 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -748,7 +748,7 @@ class DataFrame(object): +---+-----+ """ - if not isinstance(col, str): + if not isinstance(col, basestring): raise ValueError("col must be a string, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) @@ -1664,18 +1664,18 @@ class DataFrame(object): Added support for multiple columns. """ - if not isinstance(col, (str, list, tuple)): + if not isinstance(col, (basestring, list, tuple)): raise ValueError("col should be a string, list or tuple, but got %r" % type(col)) - isStr = isinstance(col, str) + isStr = isinstance(col, basestring) if isinstance(col, tuple): col = list(col) - elif isinstance(col, str): + elif isStr: col = [col] for c in col: - if not isinstance(c, str): + if not isinstance(c, basestring): raise ValueError("columns should be strings, but got %r" % type(c)) col = _to_list(self._sc, col) @@ -1707,9 +1707,9 @@ class DataFrame(object): :param col2: The name of the second column :param method: The correlation method. 
Currently only supports "pearson" """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") if not method: method = "pearson" @@ -1727,9 +1727,9 @@ class DataFrame(object): :param col1: The name of the first column :param col2: The name of the second column """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) @@ -1749,9 +1749,9 @@ class DataFrame(object): :param col2: The name of the second column. Distinct items will make the column names of the DataFrame. """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1bc889c3f45c703cc222859740a5c6a84a7ed867..4d65abc11eaf9b25aa096a2aa82672856be6670b 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase): def test_approxQuantile(self): df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() - aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aq, list)) - self.assertEqual(len(aq), 3) + for f in ["a", u"a"]: + aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) self.assertTrue(all(isinstance(q, float) for q in aq)) - aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1) + aqs = 
df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqs, list)) self.assertEqual(len(aqs), 2) self.assertTrue(isinstance(aqs[0], list)) @@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(isinstance(aqs[1], list)) self.assertEqual(len(aqs[1]), 3) self.assertTrue(all(isinstance(q, float) for q in aqs[1])) - aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1) + aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqt, list)) self.assertEqual(len(aqt), 2) self.assertTrue(isinstance(aqt[0], list)) @@ -1169,17 +1170,22 @@ class SQLTests(ReusedPySparkTestCase): def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() - corr = df.stat.corr("a", "b") + corr = df.stat.corr(u"a", "b") self.assertTrue(abs(corr - 0.95734012) < 1e-6) + def test_sampleby(self): + df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() + sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) + self.assertTrue(sampled.count() == 3) + def test_cov(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() - cov = df.stat.cov("a", "b") + cov = df.stat.cov(u"a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() - ct = df.stat.crosstab("a", "b").collect() + ct = df.stat.crosstab(u"a", "b").collect() ct = sorted(ct, key=lambda x: x[0]) for i, row in enumerate(ct): self.assertEqual(row[0], str(i))