diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 795ef0dbc4c47d4f3ee603eb8b80e45b6a52ef45..80939a1f8ab1e0257b99690b0b361ebd14f15660 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -34,7 +34,7 @@ try:
 except ImportError:
     has_pandas = False
 
-__all__ = ["SQLContext", "HiveContext"]
+__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
 
 
 def _monkey_patch_RDD(sqlCtx):
@@ -56,6 +56,31 @@ def _monkey_patch_RDD(sqlCtx):
     RDD.toDF = toDF
 
 
+class UDFRegistration(object):
+    """Wrapper for user-defined function (UDF) registration."""
+
+    def __init__(self, sqlCtx):
+        self.sqlCtx = sqlCtx
+
+    def register(self, name, f, returnType=StringType()):
+        """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it defaults to a string and conversion will automatically
+        be done. For any other return type, the produced object must match the specified type.
+
+        >>> sqlCtx.udf.register("stringLengthString", lambda x: len(x))
+        >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+        [Row(c0=u'4')]
+
+        >>> from pyspark.sql.types import IntegerType
+        >>> sqlCtx.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+        [Row(c0=4)]
+        """
+        return self.sqlCtx.registerFunction(name, f, returnType)
+
+
 class SQLContext(object):
 
     """Main entry point for Spark SQL functionality.
@@ -118,6 +143,11 @@ class SQLContext(object):
         """
         return self._ssql_ctx.getConf(key, defaultValue)
 
+    @property
+    def udf(self):
+        """Returns a UDFRegistration for registering Python functions as UDFs."""
+        return UDFRegistration(self)
+
     def registerFunction(self, name, f, returnType=StringType()):
         """Registers a lambda function as a UDF so it can be used in SQL statements.
 
@@ -198,14 +228,12 @@ class SQLContext(object):
         >>> df.collect()[0]
         Row(field1=1, field2=u'row1')
         """
+        warnings.warn("inferSchema is deprecated, please use createDataFrame instead")
 
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
 
-        schema = self._inferSchema(rdd, samplingRatio)
-        converter = _create_converter(schema)
-        rdd = rdd.map(converter)
-        return self.applySchema(rdd, schema)
+        return self.createDataFrame(rdd, None, samplingRatio)
 
     def applySchema(self, rdd, schema):
         """
@@ -230,6 +258,7 @@ class SQLContext(object):
         >>> df.collect()
         [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
         """
+        warnings.warn("applySchema is deprecated, please use createDataFrame instead")
 
         if isinstance(rdd, DataFrame):
             raise TypeError("Cannot apply schema to DataFrame")
@@ -237,23 +266,7 @@ class SQLContext(object):
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType, but got %s" % schema)
 
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
-            rows = rdd.take(10)
-
-        for row in rows:
-            _verify_type(row, schema)
-
-        # convert python objects to sql data
-        converter = _python_to_sql_converter(schema)
-        rdd = rdd.map(converter)
-
-        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return DataFrame(df, self)
+        return self.createDataFrame(rdd, schema)
 
     def createDataFrame(self, data, schema=None, samplingRatio=None):
         """
@@ -323,22 +336,42 @@ class SQLContext(object):
         if not isinstance(data, RDD):
             try:
                 # data could be list, tuple, generator ...
-                data = self._sc.parallelize(data)
+                rdd = self._sc.parallelize(data)
             except Exception:
                 raise ValueError("cannot create an RDD from type: %s" % type(data))
+        else:
+            rdd = data
 
         if schema is None:
-            return self.inferSchema(data, samplingRatio)
+            schema = self._inferSchema(rdd, samplingRatio)
+            converter = _create_converter(schema)
+            rdd = rdd.map(converter)
 
         if isinstance(schema, (list, tuple)):
-            first = data.first()
+            first = rdd.first()
             if not isinstance(first, (list, tuple)):
                 raise ValueError("each row in `rdd` should be list or tuple, "
                                  "but got %r" % type(first))
             row_cls = Row(*schema)
-            schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
 
-        return self.applySchema(data, schema)
+        # take the first few rows to verify schema
+        rows = rdd.take(10)
+        # Row() cannot be deserialized by Pyrolite
+        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+            rdd = rdd.map(tuple)
+            rows = rdd.take(10)
+
+        for row in rows:
+            _verify_type(row, schema)
+
+        # convert python objects to sql data
+        converter = _python_to_sql_converter(schema)
+        rdd = rdd.map(converter)
+
+        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        return DataFrame(df, self)
 
     def registerDataFrameAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.