Skip to content
Snippets Groups Projects
Commit d1ca634d authored by Holden Karau's avatar Holden Karau Committed by Davies Liu
Browse files

[SPARK-12300] [SQL] [PYSPARK] fix schema inferance on local collections

Current schema inference for local python collections halts as soon as there are no NullTypes. This is different than when we specify a sampling ratio of 1.0 on a distributed collection. This could result in incomplete schema information.

Author: Holden Karau <holden@us.ibm.com>

Closes #10275 from holdenk/SPARK-12300-fix-schmea-inferance-on-local-collections.
parent aa48164a
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@
import sys
import warnings
import json
from functools import reduce
if sys.version >= '3':
basestring = unicode = str
......@@ -236,14 +237,9 @@ class SQLContext(object):
if type(first) is dict:
warnings.warn("inferring schema from dict is deprecated,"
"please use pyspark.sql.Row instead")
schema = _infer_schema(first)
schema = reduce(_merge_type, map(_infer_schema, data))
if _has_nulltype(schema):
for r in data:
schema = _merge_type(schema, _infer_schema(r))
if not _has_nulltype(schema):
break
else:
raise ValueError("Some of types cannot be determined after inferring")
raise ValueError("Some of types cannot be determined after inferring")
return schema
def _inferSchema(self, rdd, samplingRatio=None):
......
......@@ -353,6 +353,17 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_infer_schema_to_local(self):
input = [{"a": 1}, {"b": "coffee"}]
rdd = self.sc.parallelize(input)
df = self.sqlCtx.createDataFrame(input)
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
self.assertEqual(df.schema, df2.schema)
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment