Commit 0f576a57 authored by Dongjoon Hyun, committed by Davies Liu

[SPARK-15244] [PYTHON] Type of column name created with createDataFrame is not consistent.

## What changes were proposed in this pull request?

**createDataFrame** returns inconsistent types for column names.
```python
>>> from pyspark.sql.types import StructType, StructField, StringType
>>> schema = StructType([StructField(u"col", StringType())])
>>> df1 = spark.createDataFrame([("a",)], schema)
>>> df1.columns # "col" is str
['col']
>>> df2 = spark.createDataFrame([("a",)], [u"col"])
>>> df2.columns # "col" is unicode
[u'col']
```

The reason is that only **StructField** has the following normalization code.
```python
if not isinstance(name, str):
    name = name.encode('utf-8')
```
This PR adds the same logic to **createDataFrame** for consistency.
```python
if isinstance(schema, list):
    schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
```
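
For reference, here is a minimal standalone sketch of the same normalization, assuming Python 2 (where `str` and `unicode` are distinct types); `normalize_column_names` is a hypothetical helper used only for illustration, not part of PySpark:

```python
# Hypothetical helper mirroring the normalization above.
# Under Python 2, unicode.encode('utf-8') returns a plain str.
def normalize_column_names(schema):
    if isinstance(schema, list):
        return [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
    return schema

print(normalize_column_names([u'name', 'age']))  # ['name', 'age'] -- both plain str
```

With the patch applied, `df2.columns` in the example above would be expected to return `['col']` (plain `str`) rather than `[u'col']`.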

## How was this patch tested?

Pass the Jenkins tests (with a new Python unit test, `test_column_name_encoding`).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13097 from dongjoon-hyun/SPARK-15244.
parent e2efe052
@@ -465,6 +465,8 @@ class SparkSession(object):
                 return (obj, )
             schema = StructType().add("value", datatype)
         else:
+            if isinstance(schema, list):
+                schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
             prepare = lambda obj: obj
 
         if isinstance(data, RDD):
@@ -228,6 +228,13 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(AnalysisException, lambda: df.select(df.c).first())
         self.assertRaises(AnalysisException, lambda: df.select(df["c"]).first())
 
+    def test_column_name_encoding(self):
+        """Ensure that created columns has `str` type consistently."""
+        columns = self.spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
+        self.assertEqual(columns, ['name', 'age'])
+        self.assertTrue(isinstance(columns[0], str))
+        self.assertTrue(isinstance(columns[1], str))
+
     def test_explode(self):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
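
For completeness, a hedged standalone version of the new test that can be run against a local Spark build (the session setup below is only for illustration and assumes a local master; the real test reuses the suite's shared fixture):

```python
from pyspark.sql import SparkSession

# Throwaway local session; the actual test relies on the test suite's fixture.
spark = SparkSession.builder.master("local[1]").appName("SPARK-15244-check").getOrCreate()

columns = spark.createDataFrame([('Alice', 1)], ['name', u'age']).columns
assert columns == ['name', 'age']
assert all(isinstance(c, str) for c in columns)  # plain str, not unicode, after the fix

spark.stop()
```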