From 43b15e01c46ea1971569f74c9201a55de39e8917 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Fri, 25 Mar 2016 09:50:06 -0700
Subject: [PATCH] [SPARK-14061][SQL] implement CreateMap

## What changes were proposed in this pull request?

As we have `CreateArray` and `CreateStruct`, we should also have
`CreateMap`. This PR adds the `CreateMap` expression, along with the
corresponding DataFrame and Python APIs.

## How was this patch tested?

Various new tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11879 from cloud-fan/create_map.
---
 python/pyspark/sql/functions.py               | 20 +++++
 .../catalyst/analysis/FunctionRegistry.scala  |  1 +
 .../catalyst/analysis/HiveTypeCoercion.scala  | 35 +++++++-
 .../expressions/complexTypeCreator.scala      | 83 ++++++++++++++++++-
 .../sql/catalyst/util/ArrayBasedMapData.scala |  5 +-
 .../ExpressionTypeCheckingSuite.scala         | 16 +++-
 .../analysis/HiveTypeCoercionSuite.scala      | 61 ++++++++++++++
 .../expressions/ComplexTypeSuite.scala        | 40 +++++++++
 .../org/apache/spark/sql/functions.scala      | 11 +++
 .../spark/sql/DataFrameComplexTypeSuite.scala |  8 +-
 .../spark/sql/DataFrameFunctionsSuite.scala   | 15 ++--
 .../spark/sql/hive/ExpressionToSQLSuite.scala |  1 +
 12 files changed, 277 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index dee3d536be..f5d959ef98 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1498,6 +1498,26 @@ def translate(srcCol, matching, replace):
 # ---------------------- Collection functions ------------------------------
 
 
+@ignore_unicode_prefix
+@since(2.0)
+def create_map(*cols):
+    """Creates a new map column.
+
+    :param cols: list of column names (string) or list of :class:`Column` expressions that are
+        grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
+
+    >>> df.select(create_map('name', 'age').alias("map")).collect()
+    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
+    >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
+    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
+    """
+    sc = SparkContext._active_spark_context
+    if len(cols) == 1 and isinstance(cols[0], (list, set)):
+        cols = cols[0]
+    jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column))
+    return Column(jc)
+
+
 @since(1.4)
 def array(*cols):
     """Creates a new array column.
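For a rough sense of how the new API reads, here is a sketch that is not part of the patch; it assumes a Spark 2.0-era `sqlContext` in scope and uses the Scala side of the same function that the Python `create_map` above delegates to:

    import org.apache.spark.sql.functions.map

    // Hypothetical two-column DataFrame mirroring the Python doctest above.
    val df = sqlContext.createDataFrame(Seq(("Alice", 2), ("Bob", 5))).toDF("name", "age")

    // Each row yields a single-entry map, e.g. Map("Alice" -> 2).
    df.select(map(df("name"), df("age")).as("m")).collect()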
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 26bb96eb08..f584a4b73a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -126,6 +126,7 @@ object FunctionRegistry {
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
     expression[Least]("least"),
+    expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[NaNvl]("nanvl"),
     expression[Coalesce]("nvl"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 0f85f44ffa..823d2495fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -160,6 +160,9 @@ object HiveTypeCoercion {
     })
   }
 
+  private def haveSameType(exprs: Seq[Expression]): Boolean =
+    exprs.map(_.dataType).distinct.length == 1
+
   /**
    * Applies any changes to [[AttributeReference]] data types that are made by other rules to
    * instances higher in the query tree.
@@ -443,13 +446,37 @@
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+      case a @ CreateArray(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonTypeAndPromoteToString(types) match {
           case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
           case None => a
         }
 
+      case m @ CreateMap(children) if m.keys.length == m.values.length &&
+        (!haveSameType(m.keys) || !haveSameType(m.values)) =>
+        val newKeys = if (haveSameType(m.keys)) {
+          m.keys
+        } else {
+          val types = m.keys.map(_.dataType)
+          findTightestCommonTypeAndPromoteToString(types) match {
+            case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
+            case None => m.keys
+          }
+        }
+
+        val newValues = if (haveSameType(m.values)) {
+          m.values
+        } else {
+          val types = m.values.map(_.dataType)
+          findTightestCommonTypeAndPromoteToString(types) match {
+            case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
+            case None => m.values
+          }
+        }
+
+        CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
+
       // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
       case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
       case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
@@ -468,21 +495,21 @@
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
-      case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+      case c @ Coalesce(es) if !haveSameType(es) =>
         val types = es.map(_.dataType)
         findWiderCommonType(types) match {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }
 
-      case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
+      case g @ Greatest(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonType(types) match {
           case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
           case None => g
         }
 
-      case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
+      case l @ Least(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
         findTightestCommonType(types) match {
           case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
           case None => l
         }
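A worked example of what the new coercion case does (a sketch, not part of the patch; it assumes a `sqlContext` in scope and the exact literal types are illustrative): keys and values are widened independently, so mixed-type keys pick up casts while already-uniform values are left alone.

    // Keys 1 (int) and 2.0 (fractional) share no exact type, so both are cast
    // to the tightest common type (here double); the string values are left
    // untouched. The analyzed plan is then equivalent to
    // map(CAST(1 AS DOUBLE), 'a', 2.0, 'b').
    val df = sqlContext.sql("SELECT map(1, 'a', 2.0, 'b') AS m")
    df.schema  // m: map<double,string>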
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index efd75295b2..c299586dde 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -69,6 +69,87 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
   override def prettyName: String = "array"
 }
 
+/**
+ * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
+ * The children are a flattened sequence of key-value pairs, e.g. (key1, value1, key2, value2, ...)
+ */
+case class CreateMap(children: Seq[Expression]) extends Expression {
+  private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
+  private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.size % 2 != 0) {
+      TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
+    } else if (keys.map(_.dataType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
+        "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else if (values.map(_.dataType).distinct.length > 1) {
+      TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
+        "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def dataType: DataType = {
+    MapType(
+      keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
+      valueType = values.headOption.map(_.dataType).getOrElse(NullType),
+      valueContainsNull = values.exists(_.nullable))
+  }
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    val keyArray = keys.map(_.eval(input)).toArray
+    if (keyArray.contains(null)) {
+      throw new RuntimeException("Cannot use null as map key!")
+    }
+    val valueArray = values.map(_.eval(input)).toArray
+    new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
+  }
+
+  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+    val arrayClass = classOf[GenericArrayData].getName
+    val mapClass = classOf[ArrayBasedMapData].getName
+    val keyArray = ctx.freshName("keyArray")
+    val valueArray = ctx.freshName("valueArray")
+    val keyData = s"new $arrayClass($keyArray)"
+    val valueData = s"new $arrayClass($valueArray)"
+    s"""
+      final boolean ${ev.isNull} = false;
+      final Object[] $keyArray = new Object[${keys.size}];
+      final Object[] $valueArray = new Object[${values.size}];
+    """ + keys.zipWithIndex.map {
+      case (key, i) =>
+        val eval = key.gen(ctx)
+        s"""
+          ${eval.code}
+          if (${eval.isNull}) {
+            throw new RuntimeException("Cannot use null as map key!");
+          } else {
+            $keyArray[$i] = ${eval.value};
+          }
+        """
+    }.mkString("\n") + values.zipWithIndex.map {
+      case (value, i) =>
+        val eval = value.gen(ctx)
+        s"""
+          ${eval.code}
+          if (${eval.isNull}) {
+            $valueArray[$i] = null;
+          } else {
+            $valueArray[$i] = ${eval.value};
+          }
+        """
+    }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);"
+  }
+
+  override def prettyName: String = "map"
+}
+
 /**
  * Returns a Row containing the evaluation of all children expressions.
  */
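A minimal expression-level sketch of the children convention above (not part of the patch; it uses only the catalyst classes shown in this file, and the result comments are illustrative): even positions are keys, odd positions are values, exactly as the `keys`/`values` lazy vals compute.

    import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal}

    // A flattened (key1, value1, key2, value2) child sequence.
    val expr = CreateMap(Seq(Literal("a"), Literal(1), Literal("b"), Literal(2)))
    expr.dataType  // MapType(StringType, IntegerType, valueContainsNull = false)
    expr.eval()    // ArrayBasedMapData with keys ["a", "b"] and values [1, 2]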
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
index d85b72ed83..d46f03ad8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala
@@ -24,7 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
   override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
 
-  // We need to check equality of map type in tests.
   override def equals(o: Any): Boolean = {
     if (!o.isInstanceOf[ArrayBasedMapData]) {
       return false
     }
@@ -35,11 +34,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
       return false
     }
 
-    ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
+    this.keyArray == other.keyArray && this.valueArray == other.valueArray
   }
 
   override def hashCode: Int = {
-    ArrayBasedMapData.toScalaMap(this).hashCode()
+    keyArray.hashCode() * 37 + valueArray.hashCode()
   }
 
   override def toString: String = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 92c8496fde..ace6e10c6e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -173,13 +173,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
       CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
     assertError(
       CreateNamedStruct(Seq(1, "a", "b", 2.0)),
-        "Only foldable StringType expressions are allowed to appear at odd position")
+      "Only foldable StringType expressions are allowed to appear at odd position")
     assertError(
       CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
-        "Only foldable StringType expressions are allowed to appear at odd position")
+      "Only foldable StringType expressions are allowed to appear at odd position")
     assertError(
       CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
-        "Field name should not be null")
+      "Field name should not be null")
+  }
+
+  test("check types for CreateMap") {
+    assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments")
+    assertError(
+      CreateMap(Seq('intField, 'stringField, 'booleanField, 'stringField)),
+      "keys of function map should all be the same type")
+    assertError(
+      CreateMap(Seq('stringField, 'intField, 'stringField, 'booleanField)),
+      "values of function map should all be the same type")
   }
 
   test("check types for ROUND") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 6f289dcc47..883ef48984 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -250,6 +250,67 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Nil))
   }
 
+  test("CreateArray casts") {
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal.create(1.0, FloatType)
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), DoubleType)
+        :: Cast(Literal(1), DoubleType)
+        :: Cast(Literal.create(1.0, FloatType), DoubleType)
+        :: Nil))
+
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateArray(Literal(1.0)
+        :: Literal(1)
+        :: Literal("a")
+        :: Nil),
+      CreateArray(Cast(Literal(1.0), StringType)
+        :: Cast(Literal(1), StringType)
+        :: Cast(Literal("a"), StringType)
+        :: Nil))
+  }
+
+  test("CreateMap casts") {
+    // type coercion for map keys
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal.create(2.0, FloatType)
+        :: Literal("b")
+        :: Nil),
+      CreateMap(Cast(Literal(1), FloatType)
+        :: Literal("a")
+        :: Cast(Literal.create(2.0, FloatType), FloatType)
+        :: Literal("b")
+        :: Nil))
+    // type coercion for map values
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Literal(1)
+        :: Cast(Literal("a"), StringType)
+        :: Literal(2)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+    // type coercion for both map keys and values
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      CreateMap(Literal(1)
+        :: Literal("a")
+        :: Literal(2.0)
+        :: Literal(3.0)
+        :: Nil),
+      CreateMap(Cast(Literal(1), DoubleType)
+        :: Cast(Literal("a"), StringType)
+        :: Cast(Literal(2.0), DoubleType)
+        :: Cast(Literal(3.0), StringType)
+        :: Nil))
+  }
+
   test("greatest/least cast") {
     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
       ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
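One behavioral consequence of the `ArrayBasedMapData.equals` change above, sketched directly against the catalyst classes (not part of the patch): equality is now positional, so the same entries in a different order no longer compare equal. This is also why the `ComplexTypeSuite` test below builds a `ListMap` to pin down element order.

    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}

    val m1 = new ArrayBasedMapData(
      new GenericArrayData(Array[Any](1, 2)), new GenericArrayData(Array[Any]("a", "b")))
    val m2 = new ArrayBasedMapData(
      new GenericArrayData(Array[Any](2, 1)), new GenericArrayData(Array[Any]("b", "a")))

    // Previously compared via toScalaMap (order-insensitive); now the backing
    // key and value arrays are compared element by element.
    m1 == m2  // false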
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9c1688b261..7c009a7360 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -134,6 +134,46 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
   }
 
+  test("CreateMap") {
+    def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
+      keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
+    }
+
+    def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+      // Catalyst maps are order-sensitive, so we create a ListMap here to preserve element order.
+      scala.collection.immutable.ListMap(keys.zip(values): _*)
+    }
+
+    val intSeq = Seq(5, 10, 15, 20, 25)
+    val longSeq = intSeq.map(_.toLong)
+    val strSeq = intSeq.map(_.toString)
+    checkEvaluation(CreateMap(Nil), Map.empty)
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(intSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
+      createMap(strSeq, longSeq))
+    checkEvaluation(
+      CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
+      createMap(longSeq, strSeq))
+
+    val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
+    checkEvaluation(
+      CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
+      createMap(intSeq, strWithNull.map(_.value)))
+    intercept[RuntimeException] {
+      checkEvaluationWithoutCodegen(
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        null, null)
+    }
+    intercept[RuntimeException] {
+      checkEvalutionWithUnsafeProjection(
+        CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
+        null, null)
+    }
+  }
+
   test("CreateStruct") {
     val row = create_row(1, 2, 3)
     val c1 = 'a.int.at(0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 304d747d4f..8abb9d7e4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -904,6 +904,17 @@ object functions {
     array((colName +: colNames).map(col) : _*)
   }
 
+  /**
+   * Creates a new map column. The input columns must be grouped as key-value pairs, e.g.
+   * (key1, value1, key2, value2, ...). The key columns must all have the same data type, and
+   * can't be null. The value columns must all have the same data type.
+   *
+   * @group normal_funcs
+   * @since 2.0
+   */
+  @scala.annotation.varargs
+  def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
+
   /**
    * Marks a DataFrame as small enough for use in broadcast joins.
   *
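The "can't be null" requirement in the scaladoc above is enforced at evaluation time rather than by the analyzer; a hedged sketch (not part of the patch; assumes a `sqlContext`, and the cast is only there to give the null key a concrete type):

    import org.apache.spark.sql.functions.{lit, map}

    val df = sqlContext.range(1).select(map(lit(null).cast("string"), lit(1)).as("m"))
    // Analysis succeeds, but materializing the rows hits CreateMap's null-key
    // check: java.lang.RuntimeException: Cannot use null as map key!
    df.collect()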
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index b76fc73b7f..72f676e622 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -41,7 +41,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
   test("UDF on array") {
     val f = udf((a: String) => a)
     val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
-    df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
+    df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
+  }
+
+  test("UDF on map") {
+    val f = udf((a: String) => a)
+    val df = Seq("a" -> 1).toDF("a", "b")
+    df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect()
   }
 
   test("SPARK-12477 accessing null element in array field") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 2aa6f8d4ac..746e25a0c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -44,15 +44,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
 
     val expectedType = ArrayType(IntegerType, containsNull = false)
     assert(row.schema(0).dataType === expectedType)
-    assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+    assert(row.getSeq[Int](0) === Seq(0, 2))
   }
 
-  // Turn this on once we add a rule to the analyzer to throw a friendly exception
-  ignore("array: throw exception if putting columns of different types into an array") {
-    val df = Seq((0, "str")).toDF("a", "b")
-    intercept[AnalysisException] {
-      df.select(array("a", "b"))
-    }
+  test("map with column expressions") {
+    val df = Seq(1 -> "a").toDF("a", "b")
+    val row = df.select(map($"a" + 1, $"b")).first()
+
+    val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
+    assert(row.schema(0).dataType === expectedType)
+    assert(row.getMap[Int, String](0) === Map(2 -> "a"))
   }
 
   test("struct with column name") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
index 4c9c48a25c..75930086ff 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -100,6 +100,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     checkSqlGeneration("SELECT isnull(null), isnull('a')")
     checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
     checkSqlGeneration("SELECT least(1,null,3)")
+    checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
     checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
     checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
     checkSqlGeneration("SELECT nvl(null, 1, 2)")
--
GitLab