diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 7b2a451ca5ee595efaa8d00bcc1515d7fc08573c..5e781a326d98cf374afdd1b8f0fbc2bd4ca9ccfc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -25,9 +25,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{Column, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, Cast, CreateStruct} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -53,13 +51,12 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { val inputColNames = map(inputCols) val args = inputColNames.map { c => schema(c).dataType match { - case DoubleType => UnresolvedAttribute(c) - case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c) - case _: NumericType | BooleanType => - Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")() + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") } } - dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol))) + dataset.select(col("*"), assembleFunc(struct(args : _*)).as(map(outputCol))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 241f82175726fada6fe66a765faf871c626fcaa3..641220a2642952841bf0ee94a180740d405523ad 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -24,13 +24,20 @@ if sys.version < "3": from itertools import imap as map from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.dataframe import Column, _to_java_column, _to_seq -__all__ = ['countDistinct', 'approxCountDistinct', 'udf'] +__all__ = [ + 'approxCountDistinct', + 'countDistinct', + 'monotonicallyIncreasingId', + 'rand', + 'randn', + 'sparkPartitionId', + 'udf'] def _create_function(name, doc=""): @@ -74,27 +81,21 @@ __all__ += _functions.keys() __all__.sort() -def rand(seed=None): - """ - Generate a random column with i.i.d. samples from U[0.0, 1.0]. - """ - sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.rand(seed) - else: - jc = sc._jvm.functions.rand() - return Column(jc) +def array(*cols): + """Creates a new array column. + :param cols: list of column names (string) or list of :class:`Column` expressions that have + the same data type. -def randn(seed=None): - """ - Generate a column with i.i.d. samples from the standard normal distribution. + >>> df.select(array('age', 'age').alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] + >>> df.select(array([df.age, df.age]).alias("arr")).collect() + [Row(arr=[2, 2]), Row(arr=[5, 5])] """ sc = SparkContext._active_spark_context - if seed: - jc = sc._jvm.functions.randn(seed) - else: - jc = sc._jvm.functions.randn() + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.array(_to_seq(sc, cols, _to_java_column)) return Column(jc) @@ -146,6 +147,28 @@ def monotonicallyIncreasingId(): return Column(sc._jvm.functions.monotonicallyIncreasingId()) +def rand(seed=None): + """Generates a random column with i.i.d. samples from U[0.0, 1.0]. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.rand(seed) + else: + jc = sc._jvm.functions.rand() + return Column(jc) + + +def randn(seed=None): + """Generates a column with i.i.d. samples from the standard normal distribution. + """ + sc = SparkContext._active_spark_context + if seed: + jc = sc._jvm.functions.randn(seed) + else: + jc = sc._jvm.functions.randn() + return Column(jc) + + def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -158,6 +181,25 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +def struct(*cols): + """Creates a new struct column. + + :param cols: list of column names (string) or list of :class:`Column` expressions + that are named or aliased. + + >>> df.select(struct('age', 'name').alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + >>> df.select(struct([df.age, df.name]).alias("struct")).collect() + [Row(struct=Row(age=2, name=u'Alice')), Row(struct=Row(age=5, name=u'Bob'))] + """ + sc = SparkContext._active_spark_context + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + jc = sc._jvm.functions.struct(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 2225621dbaabd2ab500b20d06fa5032713e3342b..c6217f07c452df6e7bb7eb2ac57249d3f49d78c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -28,13 +28,21 @@ import org.apache.spark.sql.catalyst.trees * the layout of intermediate tuples, BindReferences should be run after all such transformations. */ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) - extends Expression with trees.LeafNode[Expression] { + extends NamedExpression with trees.LeafNode[Expression] { type EvaluatedType = Any override def toString: String = s"input[$ordinal]" override def eval(input: Row): Any = input(ordinal) + + override def name: String = s"i[$ordinal]" + + override def toAttribute: Attribute = throw new UnsupportedOperationException + + override def qualifiers: Seq[String] = throw new UnsupportedOperationException + + override def exprId: ExprId = throw new UnsupportedOperationException } object BindReferences extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 242e64d3ff881f642a5bdfea2c7b513b55f10105..7e283393d0563f2b52eaf86e7ae7cc0be3854d49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -283,6 +283,23 @@ object functions { */ def abs(e: Column): Column = Abs(e.expr) + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + @scala.annotation.varargs + def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + + /** + * Creates a new array column. The input columns must all have the same data type. + * + * @group normal_funcs + */ + def array(colName: String, colNames: String*): Column = { + array((colName +: colNames).map(col) : _*) + } + /** * Returns the first column that is not null. * {{{ @@ -390,6 +407,28 @@ object functions { */ def sqrt(e: Column): Column = Sqrt(e.expr) + /** + * Creates a new struct column. The input column must be a column in a [[DataFrame]], or + * a derived column expression that is named (i.e. aliased). + * + * @group normal_funcs + */ + @scala.annotation.varargs + def struct(cols: Column*): Column = { + require(cols.forall(_.expr.isInstanceOf[NamedExpression]), + s"struct input columns must all be named or aliased ($cols)") + CreateStruct(cols.map(_.expr.asInstanceOf[NamedExpression])) + } + + /** + * Creates a new struct column that composes multiple input columns. + * + * @group normal_funcs + */ + def struct(colName: String, colNames: String*): Column = { + struct((colName +: colNames).map(col) : _*) + } + /** * Converts a string expression to upper case. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ca03713ef465821e5c723c091b4a382b1a2f3d41 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.types._ + +/** + * Test suite for functions in [[org.apache.spark.sql.functions]]. + */ +class DataFrameFunctionsSuite extends QueryTest { + + test("array with column name") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array("a", "b")).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 1)) + } + + test("array with column expression") { + val df = Seq((0, 1)).toDF("a", "b") + val row = df.select(array(col("a"), col("b") + col("b"))).first() + + val expectedType = ArrayType(IntegerType, containsNull = false) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Seq[Int]](0) === Seq(0, 2)) + } + + // Turn this on once we add a rule to the analyzer to throw a friendly exception + ignore("array: throw exception if putting columns of different types into an array") { + val df = Seq((0, "str")).toDF("a", "b") + intercept[AnalysisException] { + df.select(array("a", "b")) + } + } + + test("struct with column name") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct("a", "b")).first() + + val expectedType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(1, "str")) + } + + test("struct with column expression") { + val df = Seq((1, "str")).toDF("a", "b") + val row = df.select(struct((col("a") * 2).as("c"), col("b"))).first() + + val expectedType = StructType(Seq( + StructField("c", IntegerType, nullable = false), + StructField("b", StringType) + )) + assert(row.schema(0).dataType === expectedType) + assert(row.getAs[Row](0) === Row(2, "str")) + } + + test("struct: must use named column expression") { + intercept[IllegalArgumentException] { + struct(col("a") * 2) + } + } +}