Skip to content
Snippets Groups Projects
Commit 54b27c17 authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Wenchen Fan
Browse files

[SPARK-16278][SPARK-16279][SQL] Implement map_keys/map_values SQL functions

## What changes were proposed in this pull request?

This PR adds `map_keys` and `map_values` SQL functions in order to remove Hive fallback.

## How was this patch tested?

Passes the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13967 from dongjoon-hyun/SPARK-16278.
parent ea990f96
No related branches found
No related tags found
No related merge requests found
......@@ -171,6 +171,8 @@ object FunctionRegistry {
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
......
......@@ -43,6 +43,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
}
}
/**
 * Expression that extracts the keys of a map as an array.
 * The order of the returned keys is unspecified.
 */
@ExpressionDescription(
  usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.",
  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]")
case class MapKeys(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def prettyName: String = "map_keys"

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  // The result is an array whose element type is the key type of the child's map type.
  override def dataType: DataType = {
    val keyType = child.dataType.asInstanceOf[MapType].keyType
    ArrayType(keyType)
  }

  // Interpreted path: view the input as MapData and hand back its key array.
  override def nullSafeEval(map: Any): Any = map.asInstanceOf[MapData].keyArray()

  // Codegen path: emits an assignment of the map's key array to the result variable.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, mapRef => s"${ev.value} = ($mapRef).keyArray();")
  }
}
/**
 * Expression that extracts the values of a map as an array.
 * The order of the returned values is unspecified.
 */
@ExpressionDescription(
  usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.",
  extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]")
case class MapValues(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def prettyName: String = "map_values"

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  // The result is an array whose element type is the value type of the child's map type.
  override def dataType: DataType = {
    val valueType = child.dataType.asInstanceOf[MapType].valueType
    ArrayType(valueType)
  }

  // Interpreted path: view the input as MapData and hand back its value array.
  override def nullSafeEval(map: Any): Any = map.asInstanceOf[MapData].valueArray()

  // Codegen path: emits an assignment of the map's value array to the result variable.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, mapRef => s"${ev.value} = ($mapRef).valueArray();")
  }
}
/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
......
......@@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
test("MapKeys/MapValues") {
  // All three inputs share the same string-to-string map type.
  val stringMapType = MapType(StringType, StringType)
  val populatedMap = Literal.create(Map("a" -> "1", "b" -> "2"), stringMapType)
  val emptyMap = Literal.create(Map[String, String](), stringMapType)
  val nullMap = Literal.create(null, stringMapType)

  // A populated map yields its keys and values as arrays.
  checkEvaluation(MapKeys(populatedMap), Seq("a", "b"))
  checkEvaluation(MapValues(populatedMap), Seq("1", "2"))

  // An empty map yields empty arrays.
  checkEvaluation(MapKeys(emptyMap), Seq())
  checkEvaluation(MapValues(emptyMap), Seq())

  // A null map yields null for both expressions.
  checkEvaluation(MapKeys(nullMap), null)
  checkEvaluation(MapValues(nullMap), null)
}
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
......
......@@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
test("map_keys/map_values function") {
  // Three rows: a two-entry map, an empty map and a three-entry map.
  val df = Seq(
    (Map[Int, Int](1 -> 100, 2 -> 200), "x"),
    (Map[Int, Int](), "y"),
    (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z")
  ).toDF("a", "b")

  // Expected per-row results for each SQL function.
  val expectedKeys = Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3)))
  val expectedValues = Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300)))

  checkAnswer(df.selectExpr("map_keys(a)"), expectedKeys)
  checkAnswer(df.selectExpr("map_values(a)"), expectedValues)
}
test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),
......
......@@ -239,7 +239,6 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
"map_keys", "map_values",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment