diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9d29ef4839a437a3715c3f393aed98d8e5f171f0..db4bcbece2c1b3ea46fb7b22b96fe181a3a542bb 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -23,12 +23,18 @@ from itertools import imap from py4j.protocol import Py4JError from py4j.java_collections import MapConverter -from pyspark.rdd import _prepare_for_python_RDD +from pyspark.rdd import RDD, _prepare_for_python_RDD from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql.types import StringType, StructType, _infer_type, _verify_type, \ +from pyspark.sql.types import StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame +try: + import pandas + has_pandas = True +except ImportError: + has_pandas = False + __all__ = ["SQLContext", "HiveContext"] @@ -116,6 +122,31 @@ class SQLContext(object): self._sc._javaAccumulator, returnType.json()) + def _inferSchema(self, rdd, samplingRatio=None): + first = rdd.first() + if not first: + raise ValueError("The first row in RDD is empty, " + "can not infer schema") + if type(first) is dict: + warnings.warn("Using RDD of dict to inferSchema is deprecated," + "please use pyspark.sql.Row instead") + + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio < 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + return schema + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. @@ -171,29 +202,7 @@ class SQLContext(object): if isinstance(rdd, DataFrame): raise TypeError("Cannot apply schema to DataFrame") - first = rdd.first() - if not first: - raise ValueError("The first row in RDD is empty, " - "can not infer schema") - if type(first) is dict: - warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.sql.Row instead") - - if samplingRatio is None: - schema = _infer_schema(first) - if _has_nulltype(schema): - for row in rdd.take(100)[1:]: - schema = _merge_type(schema, _infer_schema(row)) - if not _has_nulltype(schema): - break - else: - warnings.warn("Some of types cannot be determined by the " - "first 100 rows, please try again with sampling") - else: - if samplingRatio < 0.99: - rdd = rdd.sample(False, float(samplingRatio)) - schema = rdd.map(_infer_schema).reduce(_merge_type) - + schema = self._inferSchema(rdd, samplingRatio) converter = _create_converter(schema) rdd = rdd.map(converter) return self.applySchema(rdd, schema) @@ -274,7 +283,7 @@ class SQLContext(object): raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): - raise TypeError("schema should be StructType") + raise TypeError("schema should be StructType, but got %s" % schema) # take the first few rows to verify schema rows = rdd.take(10) @@ -294,9 +303,9 @@ class SQLContext(object): df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return DataFrame(df, self) - def createDataFrame(self, rdd, schema=None, samplingRatio=None): + def createDataFrame(self, data, schema=None, samplingRatio=None): """ - Create a DataFrame from an RDD of tuple/list and an optional `schema`. + Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame. `schema` could be :class:`StructType` or a list of column names. @@ -311,12 +320,20 @@ class SQLContext(object): rows will be used to do referring. The first row will be used if `samplingRatio` is None. - :param rdd: an RDD of Row or tuple or list or dict + :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame :param schema: a StructType or list of names of columns :param samplingRatio: the sample ratio of rows used for inferring :return: a DataFrame - >>> rdd = sc.parallelize([('Alice', 1)]) + >>> l = [('Alice', 1)] + >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect() + [Row(name=u'Alice', age=1)] + + >>> d = [{'name': 'Alice', 'age': 1}] + >>> sqlCtx.createDataFrame(d).collect() + [Row(age=1, name=u'Alice')] + + >>> rdd = sc.parallelize(l) >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age']) >>> df.collect() [Row(name=u'Alice', age=1)] @@ -336,19 +353,32 @@ class SQLContext(object): >>> df3.collect() [Row(name=u'Alice', age=1)] """ - if isinstance(rdd, DataFrame): - raise TypeError("rdd is already a DataFrame") + if isinstance(data, DataFrame): + raise TypeError("data is already a DataFrame") - if isinstance(schema, StructType): - return self.applySchema(rdd, schema) - else: - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise ValueError("each row in `rdd` should be list or tuple") - row_cls = Row(*schema) - rdd = rdd.map(lambda r: row_cls(*r)) - return self.inferSchema(rdd, samplingRatio) + if has_pandas and isinstance(data, pandas.DataFrame): + data = self._sc.parallelize(data.to_records(index=False)) + if schema is None: + schema = list(data.columns) + + if not isinstance(data, RDD): + try: + # data could be list, tuple, generator ... + data = self._sc.parallelize(data) + except Exception: + raise ValueError("cannot create an RDD from type: %s" % type(data)) + + if schema is None: + return self.inferSchema(data, samplingRatio) + + if isinstance(schema, (list, tuple)): + first = data.first() + if not isinstance(first, (list, tuple)): + raise ValueError("each row in `rdd` should be list or tuple") + row_cls = Row(*schema) + schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio) + + return self.applySchema(data, schema) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3eef0cc376a2dfdd2b8aa8e97c006eb174e6ce5e..3eb56ed74cc6f30eb00eb6a6cc4dd52eb8c3c9b6 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -236,6 +236,24 @@ class DataFrame(object): """ print (self._jdf.schema().treeString()) + def show(self): + """ + Print the first 20 rows. + + >>> df.show() + age name + 2 Alice + 5 Bob + >>> df + age name + 2 Alice + 5 Bob + """ + print (self) + + def __repr__(self): + return self._jdf.showString() + def count(self): """Return the number of elements in this RDD. @@ -380,9 +398,9 @@ class DataFrame(object): """Return all column names and their data types as a list. >>> df.dtypes - [('age', 'integer'), ('name', 'string')] + [('age', 'int'), ('name', 'string')] """ - return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields] + return [(str(f.name), f.dataType.simpleString()) for f in self.schema().fields] @property def columns(self): @@ -606,6 +624,17 @@ class DataFrame(object): """ return self.select('*', col.alias(colName)) + def renameColumn(self, existing, new): + """ Rename an existing column to a new name + + >>> df.renameColumn('age', 'age2').collect() + [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')] + """ + cols = [Column(_to_java_column(c), self.sql_ctx).alias(new) + if c == existing else c + for c in self.columns] + return self.select(*cols) + def to_pandas(self): """ Collect all the rows and return a `pandas.DataFrame`. @@ -885,6 +914,12 @@ class Column(DataFrame): jc = self._jc.cast(jdt) return Column(jc, self.sql_ctx) + def __repr__(self): + if self._jdf.isComputable(): + return self._jdf.samples() + else: + return 'Column<%s>' % self._jdf.toString() + def to_pandas(self): """ Return a pandas.Series from the column @@ -1030,7 +1065,8 @@ def _test(): globs['df'] = sqlCtx.inferSchema(rdd2) globs['df2'] = sqlCtx.inferSchema(rdd3) (failure_count, test_count) = doctest.testmod( - pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS) + pyspark.sql.dataframe, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5e41e36897b5dd9d8f3e99bc583ba8d2cc6c3fd7..43e5c3a1b00fa931cbbf3ffdb97d5177260734fa 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -194,7 +194,7 @@ class SQLTests(ReusedPySparkTestCase): result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) - df2 = self.sqlCtx.createDataFrame(rdd, 1.0) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema(), df2.schema()) self.assertEqual({}, df2.map(lambda r: r.d).first()) self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 41afefe48ee5ef2f1a6be22c03f44e987ae23321..40bd7e54a9d7b57ed58ef29c27302599b6460cd0 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -52,6 +52,9 @@ class DataType(object): def typeName(cls): return cls.__name__[:-4].lower() + def simpleString(self): + return self.typeName() + def jsonValue(self): return self.typeName() @@ -145,6 +148,12 @@ class DecimalType(DataType): self.scale = scale self.hasPrecisionInfo = precision is not None + def simpleString(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal(10,0)" + def jsonValue(self): if self.hasPrecisionInfo: return "decimal(%d,%d)" % (self.precision, self.scale) @@ -180,6 +189,8 @@ class ByteType(PrimitiveType): The data type representing int values with 1 singed byte. """ + def simpleString(self): + return 'tinyint' class IntegerType(PrimitiveType): @@ -188,6 +199,8 @@ class IntegerType(PrimitiveType): The data type representing int values. """ + def simpleString(self): + return 'int' class LongType(PrimitiveType): @@ -198,6 +211,8 @@ class LongType(PrimitiveType): beyond the range of [-9223372036854775808, 9223372036854775807], please use DecimalType. """ + def simpleString(self): + return 'bigint' class ShortType(PrimitiveType): @@ -206,6 +221,8 @@ class ShortType(PrimitiveType): The data type representing int values with 2 signed bytes. """ + def simpleString(self): + return 'smallint' class ArrayType(DataType): @@ -233,6 +250,9 @@ class ArrayType(DataType): self.elementType = elementType self.containsNull = containsNull + def simpleString(self): + return 'array<%s>' % self.elementType.simpleString() + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) @@ -283,6 +303,9 @@ class MapType(DataType): self.valueType = valueType self.valueContainsNull = valueContainsNull + def simpleString(self): + return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString()) + def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) @@ -337,6 +360,9 @@ class StructField(DataType): self.nullable = nullable self.metadata = metadata or {} + def simpleString(self): + return '%s:%s' % (self.name, self.dataType.simpleString()) + def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) @@ -379,6 +405,9 @@ class StructType(DataType): """ self.fields = fields + def simpleString(self): + return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) @@ -435,6 +464,9 @@ class UserDefinedType(DataType): """ raise NotImplementedError("UDT must implement deserialize().") + def simpleString(self): + return 'null' + def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 9638ce0865db0fdd34f14ebab342510eb84ee249..41da4424ae4595fef17bd3f7fa29178f03745e7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -126,7 +126,10 @@ private[sql] class DataFrameImpl protected[sql]( logicalPlan.isInstanceOf[LocalRelation] } - override def show(): Unit = { + /** + * Internal API for Python + */ + private[sql] def showString(): String = { val data = take(20) val numCols = schema.fieldNames.length @@ -146,12 +149,16 @@ private[sql] class DataFrameImpl protected[sql]( } } - // Pad the cells and print them - println(rows.map { row => + // Pad the cells + rows.map { row => row.zipWithIndex.map { case (cell, i) => String.format(s"%-${colWidths(i)}s", cell) }.mkString(" ") - }.mkString("\n")) + }.mkString("\n") + } + + override def show(): Unit = { + println(showString) } override def join(right: DataFrame): DataFrame = {