Commit ca55dc95 authored by Reynold Xin

[SPARK-7152][SQL] Add a Column expression for partition ID.

Author: Reynold Xin <rxin@databricks.com>

Closes #5705 from rxin/df-pid and squashes the following commits:

401018f [Reynold Xin] [SPARK-7152][SQL] Add a Column expression for partition ID.
parent 9a5bbe05
@@ -75,6 +75,20 @@ __all__ += _functions.keys()
 __all__.sort()
 
 
+def approxCountDistinct(col, rsd=None):
+    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+
+    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+    [Row(c=2)]
+    """
+    sc = SparkContext._active_spark_context
+    if rsd is None:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+    else:
+        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+    return Column(jc)
+
+
 def countDistinct(col, *cols):
     """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``.
@@ -89,18 +103,16 @@ def countDistinct(col, *cols):
     return Column(jc)
 
 
-def approxCountDistinct(col, rsd=None):
-    """Returns a new :class:`Column` for approximate distinct count of ``col``.
+def sparkPartitionId():
+    """Returns a column for partition ID of the Spark task.
 
-    >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
-    [Row(c=2)]
+    Note that this is non-deterministic because it depends on data partitioning and task scheduling.
+
+    >>> df.repartition(1).select(sparkPartitionId().alias("pid")).collect()
+    [Row(pid=0), Row(pid=0)]
     """
     sc = SparkContext._active_spark_context
-    if rsd is None:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
-    else:
-        jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
-    return Column(jc)
+    return Column(sc._jvm.functions.sparkPartitionId())
 
 
 class UserDefinedFunction(object):
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.expressions

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.types.{IntegerType, DataType}


/**
 * Expression that returns the current partition id of the Spark task.
 */
case object SparkPartitionID extends Expression with trees.LeafNode[Expression] {
  self: Product =>

  override type EvaluatedType = Int

  override def nullable: Boolean = false

  override def dataType: DataType = IntegerType

  override def eval(input: Row): Int = TaskContext.get().partitionId()
}
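For context, the value this expression evaluates to is simply the id reported by TaskContext inside each running task, which is the same index that mapPartitionsWithIndex hands to its closure. The following standalone sketch (not part of the commit) illustrates that equivalence; the SparkContext `sc` is assumed to already exist, as in the Spark shell.

// Sketch only: assumes an existing SparkContext `sc`; not part of this commit.
import org.apache.spark.TaskContext

val rdd = sc.parallelize(1 to 100, numSlices = 4)

// Inside each task, TaskContext.get().partitionId() matches the index that
// mapPartitionsWithIndex passes to the closure, so every element of a given
// partition produces the same pair of equal ids.
val ids = rdd.mapPartitionsWithIndex { (index, iter) =>
  iter.map(_ => (index, TaskContext.get().partitionId()))
}.distinct().collect()

// With 4 partitions this prints (0,0), (1,1), (2,2), (3,3) in some order.
ids.foreach(println)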
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

/**
 * Package containing expressions that are specific to Spark runtime.
 */
package object expressions
@@ -276,6 +276,13 @@ object functions {
   // Non-aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * Computes the absolute value.
+   *
+   * @group normal_funcs
+   */
+  def abs(e: Column): Column = Abs(e.expr)
+
   /**
    * Returns the first column that is not null.
    * {{{
@@ -287,6 +294,13 @@ object functions {
   @scala.annotation.varargs
   def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
 
+  /**
+   * Converts a string expression to lower case.
+   *
+   * @group normal_funcs
+   */
+  def lower(e: Column): Column = Lower(e.expr)
+
   /**
    * Unary minus, i.e. negate the expression.
    * {{{
@@ -317,18 +331,13 @@ object functions {
   def not(e: Column): Column = !e
 
   /**
-   * Converts a string expression to upper case.
+   * Partition ID of the Spark task.
+   *
+   * Note that this is non-deterministic because it depends on data partitioning and task scheduling.
    *
    * @group normal_funcs
    */
-  def upper(e: Column): Column = Upper(e.expr)
-
-  /**
-   * Converts a string exprsesion to lower case.
-   *
-   * @group normal_funcs
-   */
-  def lower(e: Column): Column = Lower(e.expr)
+  def sparkPartitionId(): Column = execution.expressions.SparkPartitionID
 
   /**
    * Computes the square root of the specified float value.
@@ -338,11 +347,11 @@ object functions {
   def sqrt(e: Column): Column = Sqrt(e.expr)
 
   /**
-   * Computes the absolutle value.
+   * Converts a string expression to upper case.
   *
   * @group normal_funcs
   */
-  def abs(e: Column): Column = Abs(e.expr)
+  def upper(e: Column): Column = Upper(e.expr)
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
......
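From the DataFrame side, the new function is used like any other column expression. A minimal usage sketch (not part of the commit) is below; it assumes an existing SQLContext named `sqlContext` and its implicits for toDF, as in a Spark 1.x shell session.

// Usage sketch only (not part of the commit); assumes an existing SQLContext
// `sqlContext` and its implicits for toDF.
import sqlContext.implicits._
import org.apache.spark.sql.functions.sparkPartitionId

val df = sqlContext.sparkContext.parallelize(1 to 8, 4).map(i => (i, i)).toDF("a", "b")

// The pid column reflects whichever partition each row happens to live in, so
// its values can change if the data is repartitioned or scheduled differently.
df.select(df("a"), sparkPartitionId().as("pid")).show()

// After collapsing to a single partition, every row reports pid = 0.
df.repartition(1).select(sparkPartitionId().as("pid")).distinct().show()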
@@ -310,6 +310,14 @@ class ColumnExpressionSuite extends QueryTest {
     )
   }
 
+  test("sparkPartitionId") {
+    val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b")
+    checkAnswer(
+      df.select(sparkPartitionId()),
+      Row(0)
+    )
+  }
+
   test("lift alias out of cast") {
     compareExpressions(
       col("1234").as("name").cast("int").expr,
......
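The committed test pins the data to a single partition, so it only checks the id 0. A hypothetical extension (not in this commit) could cover more than one partition; the sketch below reuses the suite's existing TestSQLContext, toDF, and sparkPartitionId helpers, while the test name and assertion style are assumptions.

  // Hypothetical follow-up (not part of this commit): with two partitions,
  // exactly the ids 0 and 1 should appear across the collected rows.
  test("sparkPartitionId with multiple partitions") {
    val df = TestSQLContext.sparkContext.parallelize(1 to 10, 2).map(i => (i, i)).toDF("a", "b")
    val ids = df.select(sparkPartitionId()).collect().map(_.getInt(0)).toSet
    assert(ids === Set(0, 1))
  }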