diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index b7cfc8bd9c542a5ea62aac52e89e20c801fc4a48..acbaba6791850445fec8b1213ef7f8d08dd2213c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.{File} +import java.util.{List => JList} +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -44,4 +46,11 @@ private[spark] object PythonUtils { def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { sc.parallelize(List("a", null, "b")) } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](cols: JList[T]): Seq[T] = { + cols.toList.toSeq + } } diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 74305dea749c8b6a063756751dc8d6b867d79e61..a266cde51d3172e77121b9577d39f7a6184691c3 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2128,7 +2128,7 @@ class DataFrame(object): raise ValueError("should sort by at least one column") jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) sortBy = sort @@ -2159,13 +2159,20 @@ class DataFrame(object): >>> df['age'].collect() [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): jc = self._jdf.apply(item) return Column(jc, self.sql_ctx) - - # TODO projection - raise IndexError + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) def __getattr__(self, name): """ Return the column by given name @@ -2194,18 +2201,44 @@ class DataFrame(object): cols = ["*"] jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition. + """ Filtering rows using the given condition, which could be + Column expression or string of SQL expression. + + where() is an alias for filter(). >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] >>> df.where(df.age == 2).collect() [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] """ - return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) where = filter @@ -2223,7 +2256,7 @@ class DataFrame(object): """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return GroupedDataFrame(jdf, self.sql_ctx) def agg(self, *exprs): @@ -2338,7 +2371,7 @@ class GroupedDataFrame(object): assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" jcols = ListConverter().convert([c._jc for c in exprs[1:]], self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) @dfapi @@ -2633,7 +2666,7 @@ class Dsl(object): jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.Dsl.toColumns(jcols)) + sc._jvm.PythonUtils.toSeq(jcols)) return Column(jc) @staticmethod diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index 8cf59f0a1f099e14ec7d434a2a57c8e3f313bdde..50f442dd87bf3b36a4981c824116ae0030329acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql -import java.util.{List => JList} - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ @@ -169,14 +166,6 @@ object Dsl { /** Computes the absolutle value. */ def abs(e: Column): Column = Abs(e.expr) - /** - * This is a private API for Python - * TODO: move this to a private package - */ - def toColumns(cols: JList[Column]): Seq[Column] = { - cols.toList.toSeq - } - ////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////