From 37537760d19eab878a5e1a48641cc49e6cb4b989 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@databricks.com> Date: Fri, 1 May 2015 12:49:02 -0700 Subject: [PATCH] [SPARK-7274] [SQL] Create Column expression for array/struct creation. Author: Reynold Xin <rxin@databricks.com> Closes #5802 from rxin/SPARK-7274 and squashes the following commits: 19aecaa [Reynold Xin] Fixed unicode tests. bfc1538 [Reynold Xin] Export all Python functions. 2517b8c [Reynold Xin] Code review. 23da335 [Reynold Xin] Fixed Python bug. 132002e [Reynold Xin] Fixed tests. 56fce26 [Reynold Xin] Added Python support. b0d591a [Reynold Xin] Fixed debug error. 86926a6 [Reynold Xin] Added test suite. 7dbb9ab [Reynold Xin] Ok one more. 470e2f5 [Reynold Xin] One more MLlib ... e2d14f0 [Reynold Xin] [SPARK-7274][SQL] Create Column expression for array/struct creation. --- .../spark/ml/feature/VectorAssembler.scala | 13 ++- python/pyspark/sql/functions.py | 80 +++++++++++++----- .../catalyst/expressions/BoundAttribute.scala | 10 ++- .../org/apache/spark/sql/functions.scala | 41 ++++++++- .../spark/sql/DataFrameFunctionsSuite.scala | 84 +++++++++++++++++++ 5 files changed, 199 insertions(+), 29 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7b2a451ca5..5e781a326d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val args = inputColNames.map { c => schema(c).dataType match { - case DoubleType => UnresolvedAttribute(c) - case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType | BooleanType => - Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 241f821757..641220a264 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,13 +24,20 @@ if sys.version < "3": from itertools import imap as map from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.dataframe import Column, _to_java_column, _to_seq -__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] +__all__ = [ + 'approxCountDistinct', + 'countDistinct', + 'monotonicallyIncreasingId', + 'rand', + 'randn', + 'sparkPartitionId', + 'udf'] def _create_function(name, doc=""): @@ -74,27 +81,21 @@ __all__ += _functions.keys() __all__.sort() -def rand(seed=None): - """ - Generate a random column with i.i.d. samples from U[0.0, 1.0]. - """ - sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.rand(seed) - else: - jc = sc._jvm.functions.rand() - return Column(jc) +def array(*cols): + """Creates a new array column. + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. -def randn(seed=None): - """ - Generate a column with i.i.d. samples from the standard normal distribution. + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] """ sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.randn(seed) - else: - jc = sc._jvm.functions.randn() + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -146,6 +147,28 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +def rand(seed=None): + """Generates a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +def randn(seed=None): + """Generates a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -158,6 +181,25 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +def struct(*cols): + """Creates a new struct column. + + :param cols: list of column names (string) or list of :class:`Column` expressions + that are named or aliased. + + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 2225621dba..c6217f07c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends Expression with trees.LeafNode[Expression] { + extends NamedExpression with trees.LeafNode[Expression] { type EvaluatedType = Any override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) + + override def name: String = s"i[$ordinal]" + + override def toAttribute: Attribute = throw new UnsupportedOperationException + + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + + override def exprId: ExprId = throw new UnsupportedOperationException } object BindReferences extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 242e64d3ff..7e283393d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -283,6 +283,23 @@ object functions { */ def abs(e: Column): Column = Abs(e.expr) + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + @scala.annotation.varargs + def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + def array(colName: String, colNames: String*): Column = { + array((colName +: colNames).map(col) : _*) + } + /** * Returns the first column that is not null. * {{{ @@ -390,6 +407,28 @@ object functions { */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Creates a new struct column. The input column must be a column in a [[DataFrame]], or + * a derived column expression that is named (i.e. aliased). + * + * @group normal_funcs + */ + @scala.annotation.varargs + def struct(cols: Column*): Column = { + require(cols.forall(_.expr.isInstanceOf[NamedExpression]), + s"struct input columns must all be named or aliased ($cols)") + CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + } + + /** + * Creates a new struct column that composes multiple input columns. + * + * @group normal_funcs + */ + def struct(colName: String, colNames: String*): Column = { + struct((colName +: colNames).map(col) : _*) + } + /** * Converts a string expression to upper case. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala new file mode 100644 index 0000000000..ca03713ef4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ + +/** + * Test suite for functions in [[org.apache.spark.sql.functions]]. + */ +class DataFrameFunctionsSuite extends QueryTest { + + test("array with column name") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array("a", "b")).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 1)) + } + + test("array with column expression") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array(col("a"), col("b") + col("b"))).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + } + + // Turn this on once we add a rule to the analyzer to throw a friendly exception + ignore("array: throw exception if putting columns of different types into an array") { + val df = Seq((0, "str")).toDF("a", "b") + intercept[AnalysisException] { + df.select(array("a", "b")) + } + } + + test("struct with column name") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct("a", "b")).first() + + val expectedType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(1, "str")) + } + + test("struct with column expression") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("c", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct: must use named column expression") { + intercept[IllegalArgumentException] { + struct(col("a") * 2) + } + } +} -- GitLab