Skip to content
Snippets Groups Projects
Commit 88134e73 authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Wenchen Fan
Browse files

[SPARK-16288][SQL] Implement inline table generating function

## What changes were proposed in this pull request?

This PR implements `inline` table generating function.

## How was this patch tested?

Pass the Jenkins tests with new testcase.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13976 from dongjoon-hyun/SPARK-16288.
parent 54b27c17
No related branches found
No related tags found
No related merge requests found
......@@ -165,6 +165,7 @@ object FunctionRegistry {
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
expression[Inline]("inline"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
......
......@@ -195,3 +195,38 @@ case class Explode(child: Expression) extends ExplodeBase(child, position = fals
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
/**
* Explodes an array of structs into a table.
*/
@ExpressionDescription(
usage = "_FUNC_(a) - Explodes an array of structs into a table.",
extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]")
case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
override def children: Seq[Expression] = child :: Nil
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should be array of struct type, not ${child.dataType}")
}
override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) => et
}
private lazy val numFields = elementSchema.fields.length
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
} else {
for (i <- 0 until inputArray.numElements())
yield inputArray.getStruct(i, numFields)
}
}
}
......@@ -19,53 +19,48 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types._
class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).toSeq === expected)
private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
}
private final val int_array = Seq(1, 2, 3)
private final val str_array = Seq("a", "b", "c")
private final val empty_array = CreateArray(Seq.empty)
private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))
test("explode") {
val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
val str_correct_answer = Seq(
Seq(UTF8String.fromString("a")),
Seq(UTF8String.fromString("b")),
Seq(UTF8String.fromString("c")))
val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))
checkTuple(
Explode(CreateArray(Seq.empty)),
Seq.empty)
checkTuple(Explode(empty_array), Seq.empty)
checkTuple(Explode(int_array), int_correct_answer)
checkTuple(Explode(str_array), str_correct_answer)
}
checkTuple(
Explode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))
test("posexplode") {
val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
checkTuple(
Explode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
checkTuple(PosExplode(int_array), int_correct_answer)
checkTuple(PosExplode(str_array), str_correct_answer)
}
test("posexplode") {
val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
val str_correct_answer = Seq(
Seq(0, UTF8String.fromString("a")),
Seq(1, UTF8String.fromString("b")),
Seq(2, UTF8String.fromString("c")))
test("inline") {
val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
checkTuple(
PosExplode(CreateArray(Seq.empty)),
Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
Seq.empty)
checkTuple(
PosExplode(CreateArray(int_array.map(Literal(_)))),
int_correct_answer.map(InternalRow.fromSeq(_)))
checkTuple(
PosExplode(CreateArray(str_array.map(Literal(_)))),
str_correct_answer.map(InternalRow.fromSeq(_)))
Inline(CreateArray(Seq(
CreateStruct(Seq(Literal(0), Literal("a"))),
CreateStruct(Seq(Literal(1), Literal("b"))),
CreateStruct(Seq(Literal(2), Literal("c")))
))),
correct_answer)
}
}
......@@ -89,4 +89,64 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
Row(3) :: Nil)
}
test("inline raises exception on array of null type") {
val m = intercept[AnalysisException] {
spark.range(2).selectExpr("inline(array())")
}.getMessage
assert(m.contains("data type mismatch"))
}
test("inline with empty table") {
checkAnswer(
spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
Nil)
}
test("inline on literal") {
checkAnswer(
spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
Row(10, 100) :: Row(20, 200) :: Row(30, 300) ::
Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil)
}
test("inline on column") {
val df = Seq((1, 2)).toDF("a", "b")
checkAnswer(
df.selectExpr("inline(array(struct(a), struct(a)))"),
Row(1) :: Row(1) :: Nil)
checkAnswer(
df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
Row(1, 2) :: Row(1, 2) :: Nil)
// Spark think [struct<a:int>, struct<b:int>] is heterogeneous due to name difference.
val m = intercept[AnalysisException] {
df.selectExpr("inline(array(struct(a), struct(b)))")
}.getMessage
assert(m.contains("data type mismatch"))
checkAnswer(
df.selectExpr("inline(array(struct(a), named_struct('a', b)))"),
Row(1) :: Row(2) :: Nil)
// Spark think [struct<a:int>, struct<col1:int>] is heterogeneous due to name difference.
val m2 = intercept[AnalysisException] {
df.selectExpr("inline(array(struct(a), struct(2)))")
}.getMessage
assert(m2.contains("data type mismatch"))
checkAnswer(
df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"),
Row(1) :: Row(2) :: Nil)
checkAnswer(
df.selectExpr("struct(a)").selectExpr("inline(array(*))"),
Row(1) :: Nil)
checkAnswer(
df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
Row(1) :: Row(2) :: Nil)
}
}
......@@ -241,9 +241,6 @@ private[sql] class HiveSessionCatalog(
"hash", "java_method", "histogram_numeric",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string",
// table generating function
"inline"
"xpath_number", "xpath_short", "xpath_string"
)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment