diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ca50fd6f05867173f27e733e26bf558549f67f24..68c9cb0c020188979e5b1f09ac9cd89fb92d57b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -56,7 +56,7 @@ object Column {
 class Column(
     sqlContext: Option[SQLContext],
     plan: Option[LogicalPlan],
-    val expr: Expression)
+    protected[sql] val expr: Expression)
   extends DataFrame(sqlContext, plan) with ExpressionApi {
 
   /** Turns a Catalyst expression into a `Column`. */
@@ -437,9 +437,7 @@ class Column(
   override def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
 
   /**
-   * An expression that gets an
-   * @param ordinal
-   * @return
+   * An expression that gets an item at position `ordinal` out of an array.
    */
   override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
 
@@ -490,11 +488,38 @@ class Column(
    * {{{
    *   // Casts colA to IntegerType.
    *   import org.apache.spark.sql.types.IntegerType
-   *   df.select(df("colA").as(IntegerType))
+   *   df.select(df("colA").cast(IntegerType))
+
+   *   // equivalent to
+   *   df.select(df("colA").cast("int"))
    * }}}
    */
   override def cast(to: DataType): Column = Cast(expr, to)
 
+  /**
+   * Casts the column to a different data type, using the canonical string representation
+   * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`,
+   * `float`, `double`, `decimal`, `date`, `timestamp`.
+   * {{{
+   *   // Casts colA to integer.
+   *   df.select(df("colA").cast("int"))
+   * }}}
+   */
+  override def cast(to: String): Column = Cast(expr, to.toLowerCase match {
+    case "string" => StringType
+    case "boolean" => BooleanType
+    case "byte" => ByteType
+    case "short" => ShortType
+    case "int" => IntegerType
+    case "long" => LongType
+    case "float" => FloatType
+    case "double" => DoubleType
+    case "decimal" => DecimalType.Unlimited
+    case "date" => DateType
+    case "timestamp" => TimestampType
+    case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+  })
+
   override def desc: Column = SortOrder(expr, Descending)
 
   override def asc: Column = SortOrder(expr, Ascending)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 94c13a5c26678f28c31641577009693f8b7586ed..1ff25adcf836ad04aed69fb3786f9a6e0441056b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -208,7 +208,7 @@ class DataFrame protected[sql](
   }
 
   /**
-   * Returns a new [[DataFrame]] sorted by the specified column, in ascending column.
+   * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
    * {{{
    *   // The following 3 are equivalent
    *   df.sort("sortcol")
@@ -216,8 +216,9 @@ class DataFrame protected[sql](
    *   df.sort($"sortcol".asc)
    * }}}
    */
-  override def sort(colName: String): DataFrame = {
-    Sort(Seq(SortOrder(apply(colName).expr, Ascending)), global = true, logicalPlan)
+  @scala.annotation.varargs
+  override def sort(sortCol: String, sortCols: String*): DataFrame = {
+    orderBy(apply(sortCol), sortCols.map(apply) :_*)
   }
 
   /**
@@ -239,6 +240,15 @@ class DataFrame protected[sql](
     Sort(sortOrder, global = true, logicalPlan)
   }
 
+  /**
+   * Returns a new [[DataFrame]] sorted by the given expressions.
+   * This is an alias of the `sort` function.
+   */
+  @scala.annotation.varargs
+  override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
+    sort(sortCol, sortCols :_*)
+  }
+
   /**
    * Returns a new [[DataFrame]] sorted by the given expressions.
    * This is an alias of the `sort` function.
@@ -401,6 +411,16 @@ class DataFrame protected[sql](
    */
   override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
 
+  /**
+   * Aggregates on the entire [[DataFrame]] without groups.
+   * {{{
+   *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
+   *   df.agg(Map("age" -> "max", "salary" -> "avg"))
+   *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
+   * }}}
+   */
+  override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap)
+
   /**
    * Aggregates on the entire [[DataFrame]] without groups.
    * {{
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index f47ff995e919b9f2fc5d2433bec2f19a86f7e545..75717e7cd842c506b08414c5fb928a8609ebaf11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -62,6 +62,11 @@ object Dsl {
    */
   def col(colName: String): Column = new Column(colName)
 
+  /**
+   * Returns a [[Column]] based on the given column name. Alias of [[col]].
+   */
+  def column(colName: String): Column = new Column(colName)
+
   /**
    * Creates a [[Column]] of literal value.
    */
@@ -96,6 +101,7 @@ object Dsl {
   def sumDistinct(e: Column): Column = SumDistinct(e.expr)
   def count(e: Column): Column = Count(e.expr)
 
+  @scala.annotation.varargs
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
index 1f1e9bd9899f665a5a3010f302de432873f6e053..1c948cbbfe58f09587451e22b94455140ee47ab8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
@@ -58,7 +58,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }
 
   /**
-   * Compute aggregates by specifying a map from column name to aggregate methods.
+   * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+   * [[DataFrame]] will also contain the grouping columns.
+   *
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
@@ -76,7 +78,9 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }
 
   /**
-   * Compute aggregates by specifying a map from column name to aggregate methods.
+   * Compute aggregates by specifying a map from column name to aggregate methods. The resulting
+   * [[DataFrame]] will also contain the grouping columns.
+   *
    * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`.
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
@@ -91,12 +95,15 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
   }
 
   /**
-   * Compute aggregates by specifying a series of aggregate columns.
-   * The available aggregate methods are defined in [[org.apache.spark.sql.dsl]].
+   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
+   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   *
+   * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+   *
    * {{{
    *   // Selects the age of the oldest employee and the aggregate expense for each department
    *   import org.apache.spark.sql.dsl._
-   *   df.groupBy("department").agg(max($"age"), sum($"expense"))
+   *   df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
    * }}}
    */
   @scala.annotation.varargs
@@ -109,31 +116,39 @@ class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expressi
     new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
   }
 
-  /** Count the number of rows for each group. */
+  /**
+   * Count the number of rows for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   */
   override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
 
   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def mean(): DataFrame = aggregateNumericColumns(Average)
 
   /**
    * Compute the max value for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def max(): DataFrame = aggregateNumericColumns(Max)
 
   /**
    * Compute the mean value for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def avg(): DataFrame = aggregateNumericColumns(Average)
 
   /**
    * Compute the min value for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def min(): DataFrame = aggregateNumericColumns(Min)
 
   /**
    * Compute the sum for each numeric columns for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
    */
   override def sum(): DataFrame = aggregateNumericColumns(Sum)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 59634082f61c29c578dbdda112c94b997a4d2dfd..eb0eb3f32560ccba7a41f792c9e8435509b00197 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -113,16 +113,22 @@ private[sql] trait DataFrameSpecificApi {
 
   def agg(exprs: Map[String, String]): DataFrame
 
+  def agg(exprs: java.util.Map[String, String]): DataFrame
+
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame
 
-  def sort(colName: String): DataFrame
+  @scala.annotation.varargs
+  def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+
+  @scala.annotation.varargs
+  def sort(sortCol: String, sortCols: String*): DataFrame
 
   @scala.annotation.varargs
   def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame
 
   @scala.annotation.varargs
-  def sort(sortExpr: Column, sortExprs: Column*): DataFrame
+  def orderBy(sortCol: String, sortCols: String*): DataFrame
 
   def join(right: DataFrame): DataFrame
 
@@ -257,6 +263,7 @@ private[sql] trait ExpressionApi {
   def getField(fieldName: String): Column
 
   def cast(to: DataType): Column
+  def cast(to: String): Column
 
   def asc: Column
   def desc: Column
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
new file mode 100644
index 0000000000000000000000000000000000000000..639436368c4a35d4be843abe0ee89bb5589cc28f
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import com.google.common.collect.ImmutableMap;
+
+import org.apache.spark.sql.Column;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.types.DataTypes;
+
+import static org.apache.spark.sql.Dsl.*;
+
+/**
+ * This test doesn't actually run anything. It is here to check the API compatibility for Java.
+ */
+public class JavaDsl {
+
+  public static void testDataFrame(final DataFrame df) {
+    DataFrame df1 = df.select("colA");
+    df1 = df.select("colA", "colB");
+
+    df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
+
+    df1 = df.filter(col("colA"));
+
+    java.util.Map<String, String> aggExprs = ImmutableMap.<String, String>builder()
+      .put("colA", "sum")
+      .put("colB", "avg")
+      .build();
+
+    df1 = df.agg(aggExprs);
+
+    df1 = df.groupBy("groupCol").agg(aggExprs);
+
+    df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");
+
+    df.orderBy("colA");
+    df.orderBy("colA", "colB", "colC");
+    df.orderBy(col("colA").desc());
+    df.orderBy(col("colA").desc(), col("colB").asc());
+
+    df.sort("colA");
+    df.sort("colA", "colB", "colC");
+    df.sort(col("colA").desc());
+    df.sort(col("colA").desc(), col("colB").asc());
+
+    df.as("b");
+
+    df.limit(5);
+
+    df.unionAll(df1);
+    df.intersect(df1);
+    df.except(df1);
+
+    df.sample(true, 0.1, 234);
+
+    df.head();
+    df.head(5);
+    df.first();
+    df.count();
+  }
+
+  public static void testColumn(final Column c) {
+    c.asc();
+    c.desc();
+
+    c.endsWith("abcd");
+    c.startsWith("afgasdf");
+
+    c.like("asdf%");
+    c.rlike("wef%asdf");
+
+    c.as("newcol");
+
+    c.cast("int");
+    c.cast(DataTypes.IntegerType);
+  }
+
+  public static void testDsl() {
+    // Creating a column.
+    Column c = col("abcd");
+    Column c1 = column("abcd");
+
+    // Literals
+    Column l1 = lit(1);
+    Column l2 = lit(1.0);
+    Column l3 = lit("abcd");
+
+    // Functions
+    Column a = upper(c);
+    a = lower(c);
+    a = sqrt(c);
+    a = abs(c);
+
+    // Aggregates
+    a = min(c);
+    a = max(c);
+    a = sum(c);
+    a = sumDistinct(c);
+    a = countDistinct(c, a);
+    a = avg(c);
+    a = first(c);
+    a = last(c);
+  }
+}
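
Usage note: a minimal Scala sketch of the APIs this patch adds (string-based `cast`,
varargs `sort`/`orderBy` on column names, map-based `agg`, and the `column` alias of
`col`). It assumes a DataFrame `df` is already in scope; the column names "dept",
"age", and "salary" are hypothetical and not part of the patch.

    import org.apache.spark.sql.Dsl._

    // Cast by the canonical type name; equivalent to df("age").cast(IntegerType).
    val withIntAge = df.select(df("age").cast("int"))

    // Sort by multiple column names, all in ascending order; orderBy is an alias.
    val sorted  = df.sort("dept", "age")
    val ordered = df.orderBy("dept", "age")

    // `column` is an alias of `col`.
    val deptCol = column("dept")

    // Ungrouped aggregation: shorthand for df.groupBy().agg(...).
    val stats = df.agg(Map("age" -> "max", "salary" -> "avg"))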