Commit f76d2e55 authored by Davies Liu, committed by Reynold Xin

[SPARK-6603] [PySpark] [SQL] add SQLContext.udf and deprecate inferSchema() and applySchema

This PR creates an alias for `registerFunction` as `udf.register`, to be consistent with the Scala API.
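A minimal usage sketch of the new alias, assuming an active SQLContext `sqlCtx` as in the existing doctests (the `strLen` name is only illustrative):

>>> from pyspark.sql.types import IntegerType
>>> sqlCtx.udf.register("strLen", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT strLen('spark')").collect()
[Row(c0=5)]
>>> # the existing call stays equivalent
>>> sqlCtx.registerFunction("strLen2", lambda x: len(x), IntegerType())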

It also deprecates inferSchema() and applySchema(), showing a warning for them.
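A rough migration sketch, assuming a running SparkContext `sc` and SQLContext `sqlCtx`; the deprecated calls keep working but now emit a warning:

>>> rdd = sc.parallelize([(1, u'row1'), (2, u'row2')])
>>> # before: sqlCtx.inferSchema(rdd) or sqlCtx.applySchema(rdd, schema)
>>> df = sqlCtx.createDataFrame(rdd, ['field1', 'field2'])
>>> df.collect()
[Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]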

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5273 from davies/udf and squashes the following commits:

476e947 [Davies Liu] address comments
c096fdb [Davies Liu] add SQLContext.udf and deprecate inferSchema() and applySchema
parent df355008
@@ -34,7 +34,7 @@ try:
except ImportError:
has_pandas = False
__all__ = ["SQLContext", "HiveContext"]
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
def _monkey_patch_RDD(sqlCtx):
@@ -56,6 +56,31 @@ def _monkey_patch_RDD(sqlCtx):
RDD.toDF = toDF
class UDFRegistration(object):
"""Wrapper for register UDF"""
def __init__(self, sqlCtx):
self.sqlCtx = sqlCtx
def register(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
In addition to a name and the function itself, the return type can be optionally specified.
When the return type is not given it defaults to a string and conversion will automatically
be done. For any other return type, the produced object must match the specified type.
>>> sqlCtx.udf.register("stringLengthString", lambda x: len(x))
>>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
[Row(c0=4)]
"""
return self.sqlCtx.registerFunction(name, f, returnType)
class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -118,6 +143,11 @@ class SQLContext(object):
"""
return self._ssql_ctx.getConf(key, defaultValue)
@property
def udf(self):
"""Wrapper for register Python function as UDF """
return UDFRegistration(self)
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -198,14 +228,12 @@ class SQLContext(object):
>>> df.collect()[0]
Row(field1=1, field2=u'row1')
"""
warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
schema = self._inferSchema(rdd, samplingRatio)
converter = _create_converter(schema)
rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
return self.createDataFrame(rdd, None, samplingRatio)
def applySchema(self, rdd, schema):
"""
@@ -230,6 +258,7 @@ class SQLContext(object):
>>> df.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
"""
warnings.warn("applySchema is deprecated, please use createDataFrame instead")
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
@@ -237,23 +266,7 @@ class SQLContext(object):
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType, but got %s" % schema)
# take the first few rows to verify schema
rows = rdd.take(10)
# Row() cannot be deserialized by Pyrolite
if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
rdd = rdd.map(tuple)
rows = rdd.take(10)
for row in rows:
_verify_type(row, schema)
# convert python objects to sql data
converter = _python_to_sql_converter(schema)
rdd = rdd.map(converter)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
return self.createDataFrame(rdd, schema)
def createDataFrame(self, data, schema=None, samplingRatio=None):
"""
@@ -323,22 +336,42 @@ class SQLContext(object):
if not isinstance(data, RDD):
try:
# data could be list, tuple, generator ...
data = self._sc.parallelize(data)
rdd = self._sc.parallelize(data)
except Exception:
raise ValueError("cannot create an RDD from type: %s" % type(data))
else:
rdd = data
if schema is None:
return self.inferSchema(data, samplingRatio)
schema = self._inferSchema(rdd, samplingRatio)
converter = _create_converter(schema)
rdd = rdd.map(converter)
if isinstance(schema, (list, tuple)):
first = data.first()
first = rdd.first()
if not isinstance(first, (list, tuple)):
raise ValueError("each row in `rdd` should be list or tuple, "
"but got %r" % type(first))
row_cls = Row(*schema)
schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
return self.applySchema(data, schema)
# take the first few rows to verify schema
rows = rdd.take(10)
# Row() cannot be deserialized by Pyrolite
if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
rdd = rdd.map(tuple)
rows = rdd.take(10)
for row in rows:
_verify_type(row, schema)
# convert python objects to sql data
converter = _python_to_sql_converter(schema)
rdd = rdd.map(converter)
jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
def registerDataFrameAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
......