diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e55f285a778c408df30b361de3423fc70314854c..6a6dfbc5851b86457aee71b2a3bcbd752c56f2c4 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2284,6 +2284,18 @@ class DataFrame(object): """ return self.select('*', col.alias(colName)) + def to_pandas(self): + """ + Collect all the rows and return a `pandas.DataFrame`. + + >>> df.to_pandas() # doctest: +SKIP + age name + 0 2 Alice + 1 5 Bob + """ + import pandas as pd + return pd.DataFrame.from_records(self.collect(), columns=self.columns) + # Having SchemaRDD for backward compatibility (for docs) class SchemaRDD(DataFrame): @@ -2551,6 +2563,19 @@ class Column(DataFrame): jc = self._jc.cast(jdt) return Column(jc, self.sql_ctx) + def to_pandas(self): + """ + Return a pandas.Series from the column + + >>> df.age.to_pandas() # doctest: +SKIP + 0 2 + 1 5 + dtype: int64 + """ + import pandas as pd + data = [c for c, in self.collect()] + return pd.Series(data) + def _aggregate_func(name, doc=""): """ Create a function for aggregator by name"""