Commit 37537760 authored by Reynold Xin, committed by Xiangrui Meng

[SPARK-7274] [SQL] Create Column expression for array/struct creation.

Author: Reynold Xin <rxin@databricks.com>

Closes #5802 from rxin/SPARK-7274 and squashes the following commits:

19aecaa [Reynold Xin] Fixed unicode tests.
bfc1538 [Reynold Xin] Export all Python functions.
2517b8c [Reynold Xin] Code review.
23da335 [Reynold Xin] Fixed Python bug.
132002e [Reynold Xin] Fixed tests.
56fce26 [Reynold Xin] Added Python support.
b0d591a [Reynold Xin] Fixed debug error.
86926a6 [Reynold Xin] Added test suite.
7dbb9ab [Reynold Xin] Ok one more.
470e2f5 [Reynold Xin] One more MLlib ...
e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation.
parent 16860327
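
At a glance: before this commit, building array or struct columns from the public API meant reaching into catalyst internals (e.g. `new Column(CreateStruct(...))`); the commit adds `array` and `struct` to `functions` in both Scala and Python. A minimal sketch of the new Scala calls (column names are illustrative):

    import org.apache.spark.sql.functions.{array, col, struct}

    // struct: pack named (or aliased) columns into a single struct column
    val packed = struct(col("a"), col("b"))

    // array: combine columns of one shared data type into an array column
    val items = array(col("a"), col("b"))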
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{Column, DataFrame, Row}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
     val inputColNames = map(inputCols)
     val args = inputColNames.map { c =>
       schema(c).dataType match {
-        case DoubleType => UnresolvedAttribute(c)
-        case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
-        case _: NumericType | BooleanType =>
-          Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
+        case DoubleType => dataset(c)
+        case _: VectorUDT => dataset(c)
+        case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
       }
     }
-    dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
+    dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol)))
   }

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
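
The VectorAssembler change above is a consumer of the new API: hand-built catalyst expressions (`UnresolvedAttribute`, `Alias`, `Cast`, `new Column(CreateStruct(args))`) are replaced by `dataset(c)`, `.cast(...)`, `.as(...)`, and `struct(args: _*)`. A standalone sketch of the same pattern, with illustrative names and omitting the MLlib-specific VectorUDT case:

    import org.apache.spark.sql.{Column, DataFrame}
    import org.apache.spark.sql.functions.struct
    import org.apache.spark.sql.types.{BooleanType, DoubleType, NumericType}

    // Coerce numeric/boolean inputs to DoubleType, then pack them into one
    // struct column -- the shape of the new transform() body.
    def packInputs(dataset: DataFrame, inputCols: Seq[String]): Column = {
      val args = inputCols.map { c =>
        dataset.schema(c).dataType match {
          case DoubleType => dataset(c)
          case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(c + "_double")
          case other => sys.error(s"unsupported input type $other")
        }
      }
      struct(args: _*)  // was: new Column(CreateStruct(args))
    }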
python/pyspark/sql/functions.py
@@ -24,13 +24,20 @@ if sys.version < "3":
     from itertools import imap as map

 from pyspark import SparkContext
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
 from pyspark.sql.dataframe import Column, _to_java_column, _to_seq


-__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+__all__ = [
+    'approxCountDistinct',
+    'countDistinct',
+    'monotonicallyIncreasingId',
+    'rand',
+    'randn',
+    'sparkPartitionId',
+    'udf']


 def _create_function(name, doc=""):
@@ -74,27 +81,21 @@ __all__ += _functions.keys()
 __all__.sort()


-def rand(seed=None):
-    """
-    Generate a random column with i.i.d. samples from U[0.0, 1.0].
-    """
-    sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.rand(seed)
-    else:
-        jc = sc._jvm.functions.rand()
-    return Column(jc)
+def array(*cols):
+    """Creates a new array column.

+    :param cols: list of column names (string) or list of :class:`Column` expressions that have
+        the same data type.

-def randn(seed=None):
-    """
-    Generate a column with i.i.d. samples from the standard normal distribution.
-    """
+    >>> df.select(array('age', 'age').alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    >>> df.select(array([df.age, df.age]).alias("arr")).collect()
+    [Row(arr=[2, 2]), Row(arr=[5, 5])]
+    """
     sc = SparkContext._active_spark_context
-    if seed:
-        jc = sc._jvm.functions.randn(seed)
-    else:
-        jc = sc._jvm.functions.randn()
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column))
     return Column(jc)
@@ -146,6 +147,28 @@ def monotonicallyIncreasingId():
     return Column(sc._jvm.functions.monotonicallyIncreasingId())


+def rand(seed=None):
+    """Generates a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """Generates a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def sparkPartitionId():
     """A column for partition ID of the Spark task.
@@ -158,6 +181,25 @@ def sparkPartitionId():
     return Column(sc._jvm.functions.sparkPartitionId())


+@ignore_unicode_prefix
+def struct(*cols):
+    """Creates a new struct column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions
+        that are named or aliased.
+
+    >>> df.select(struct('age', 'name').alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    >>> df.select(struct([df.age, df.name]).alias("struct")).collect()
+    [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))]
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python
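
For orientation, the Python `array` and `struct` above are thin py4j wrappers: `_to_seq(sc, cols, _to_java_column)` converts the Python columns into a JVM seq, and the call lands on the JVM-side `functions` object whose diff appears below. A sketch of the equivalent direct Scala calls (column names taken from the doctests):

    import org.apache.spark.sql.functions.{array, col, struct}

    // What sc._jvm.functions.array / .struct resolve to on the JVM side
    val arr = array(col("age"), col("age"))
    val st  = struct(col("age"), col("name"))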
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees
  * the layout of intermediate tuples, BindReferences should be run after all such transformations.
  */
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
-  extends Expression with trees.LeafNode[Expression] {
+  extends NamedExpression with trees.LeafNode[Expression] {

   type EvaluatedType = Any

   override def toString: String = s"input[$ordinal]"

   override def eval(input: Row): Any = input(ordinal)
+
+  override def name: String = s"i[$ordinal]"
+
+  override def toAttribute: Attribute = throw new UnsupportedOperationException
+
+  override def qualifiers: Seq[String] = throw new UnsupportedOperationException
+
+  override def exprId: ExprId = throw new UnsupportedOperationException
 }

 object BindReferences extends Logging {
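
Context for the BoundReference change: `CreateStruct` keeps `NamedExpression` children so struct field names survive analysis, and `BindReferences` later rewrites those children into `BoundReference`s, so `BoundReference` must itself satisfy the `NamedExpression` bound; it does so with a synthetic name while the attribute-only members throw. A standalone sketch of that pattern, using illustrative (non-catalyst) names:

    // Illustrative stand-ins, not catalyst's real types.
    trait NamedExpr { def name: String; def exprId: Long }

    // Post-binding node: downstream code only consults `name`, so the
    // attribute-only member can legitimately throw.
    case class BoundRef(ordinal: Int) extends NamedExpr {
      override def name: String = s"i[$ordinal]"
      override def exprId: Long = throw new UnsupportedOperationException
    }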
sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}

 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -283,6 +283,23 @@ object functions {
    */
   def abs(e: Column): Column = Abs(e.expr)

+  /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def array(cols: Column*): Column = CreateArray(cols.map(_.expr))
+
+  /**
+   * Creates a new array column. The input columns must all have the same data type.
+   *
+   * @group normal_funcs
+   */
+  def array(colName: String, colNames: String*): Column = {
+    array((colName +: colNames).map(col) : _*)
+  }
+
   /**
    * Returns the first column that is not null.
    * {{{
@@ -390,6 +407,28 @@ object functions {
    */
   def sqrt(e: Column): Column = Sqrt(e.expr)

+  /**
+   * Creates a new struct column. The input column must be a column in a [[DataFrame]], or
+   * a derived column expression that is named (i.e. aliased).
+   *
+   * @group normal_funcs
+   */
+  @scala.annotation.varargs
+  def struct(cols: Column*): Column = {
+    require(cols.forall(_.expr.isInstanceOf[NamedExpression]),
+      s"struct input columns must all be named or aliased ($cols)")
+    CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression]))
+  }
+
+  /**
+   * Creates a new struct column that composes multiple input columns.
+   *
+   * @group normal_funcs
+   */
+  def struct(colName: String, colNames: String*): Column = {
+    struct((colName +: colNames).map(col) : _*)
+  }
+
   /**
    * Converts a string expression to upper case.
    *
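
A short usage sketch of the new Scala functions (assumes a 1.4-era `sqlContext` with `import sqlContext.implicits._` in scope; data values are illustrative):

    import org.apache.spark.sql.functions._

    val df = Seq((1, "a"), (2, "b")).toDF("num", "str")

    // array: all inputs must share one data type
    df.select(array($"num", $"num" * 2).as("nums")).show()

    // struct: inputs must be named columns or aliased expressions
    df.select(struct($"num", ($"num" + 1).as("next")).as("s")).show()

    // an unnamed expression is rejected eagerly by the require(...) above:
    // struct($"num" * 2)  // IllegalArgumentException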
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
/**
 * Test suite for functions in [[org.apache.spark.sql.functions]].
 */
class DataFrameFunctionsSuite extends QueryTest {

  test("array with column name") {
    val df = Seq((0, 1)).toDF("a", "b")
    val row = df.select(array("a", "b")).first()

    val expectedType = ArrayType(IntegerType, containsNull = false)
    assert(row.schema(0).dataType === expectedType)
    assert(row.getAs[Seq[Int]](0) === Seq(0, 1))
  }

  test("array with column expression") {
    val df = Seq((0, 1)).toDF("a", "b")
    val row = df.select(array(col("a"), col("b") + col("b"))).first()

    val expectedType = ArrayType(IntegerType, containsNull = false)
    assert(row.schema(0).dataType === expectedType)
    assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
  }

  // Turn this on once we add a rule to the analyzer to throw a friendly exception
  ignore("array: throw exception if putting columns of different types into an array") {
    val df = Seq((0, "str")).toDF("a", "b")
    intercept[AnalysisException] {
      df.select(array("a", "b"))
    }
  }

  test("struct with column name") {
    val df = Seq((1, "str")).toDF("a", "b")
    val row = df.select(struct("a", "b")).first()

    val expectedType = StructType(Seq(
      StructField("a", IntegerType, nullable = false),
      StructField("b", StringType)
    ))
    assert(row.schema(0).dataType === expectedType)
    assert(row.getAs[Row](0) === Row(1, "str"))
  }

  test("struct with column expression") {
    val df = Seq((1, "str")).toDF("a", "b")
    val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first()

    val expectedType = StructType(Seq(
      StructField("c", IntegerType, nullable = false),
      StructField("b", StringType)
    ))
    assert(row.schema(0).dataType === expectedType)
    assert(row.getAs[Row](0) === Row(2, "str"))
  }

  test("struct: must use named column expression") {
    intercept[IllegalArgumentException] {
      struct(col("a") * 2)
    }
  }
}