Skip to content
Snippets Groups Projects
Commit d0d28507 authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Wenchen Fan
Browse files

[SPARK-16286][SQL] Implement stack table generating function

## What changes were proposed in this pull request?

This PR implements `stack` table generating function.

## How was this patch tested?

Pass the Jenkins tests, including the new test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14033 from dongjoon-hyun/SPARK-16286.
parent fdde7d0a
No related branches found
No related tags found
No related merge requests found
......@@ -182,6 +182,7 @@ object FunctionRegistry {
expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[Stack]("stack"),
expression[CreateStruct]("struct"),
expression[CaseWhen]("when"),
......
......@@ -93,6 +93,59 @@ case class UserDefinedGenerator(
override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
}
/**
 * Separates v1, ..., vk into n rows. Each row will have ceil(k/n) columns; when the
 * arguments do not fill the last row, the remaining cells are NULL. n must be a
 * foldable positive integer constant.
 * {{{
 *   SELECT stack(2, 1, 2, 3) ->
 *   1  2
 *   3  NULL
 * }}}
 */
@ExpressionDescription(
  usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.",
  extended = "> SELECT _FUNC_(2, 1, 2, 3);\n [1,2]\n [3,null]")
case class Stack(children: Seq[Expression])
  extends Expression with Generator with CodegenFallback {

  // Number of generated rows: the first argument, evaluated as a constant.
  // Only forced after checkInputDataTypes has verified it is a foldable IntegerType.
  private lazy val numRows = children.head.eval().asInstanceOf[Int]

  // Columns per generated row: the remaining k arguments spread over numRows rows.
  private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.length <= 1) {
      TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
    } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) {
      TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.")
    } else {
      // Each value must match the data type of the output column it lands in
      // (column types come from the first row's values); report the first mismatch.
      val firstMismatch = (1 until children.length).collectFirst {
        case i if children(i).dataType != elementSchema.fields((i - 1) % numFields).dataType =>
          val j = (i - 1) % numFields
          TypeCheckResult.TypeCheckFailure(
            s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
              s"Argument $i (${children(i).dataType})")
      }
      firstMismatch.getOrElse(TypeCheckResult.TypeCheckSuccess)
    }
  }

  // The output schema is derived from the first numFields values: col0, col1, ...
  override def elementSchema: StructType =
    StructType(children.tail.take(numFields).zipWithIndex.map {
      case (e, index) => StructField(s"col$index", e.dataType)
    })

  override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
    // Evaluate every value argument once, then slice the results row by row.
    val evaluated = children.tail.map(_.eval(input)).toArray
    (0 until numRows).map { rowIdx =>
      val base = rowIdx * numFields
      val cells = (0 until numFields).map { colIdx =>
        val idx = base + colIdx
        // Pad the last row with NULLs when the arguments run out.
        if (idx < evaluated.length) evaluated(idx) else null
      }
      InternalRow(cells: _*)
    }
  }
}
/**
* A base class for Explode and PosExplode
*/
......
......@@ -63,4 +63,22 @@ class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
))),
correct_answer)
}
test("stack") {
  // Build a Stack expression from literal arguments.
  def mkStack(args: Any*): Stack = Stack(args.map(Literal(_)))

  // Row/column splitting, including NULL padding of the last row.
  checkTuple(mkStack(1, 1), Seq(create_row(1)))
  checkTuple(mkStack(1, 1, 2), Seq(create_row(1, 2)))
  checkTuple(mkStack(2, 1, 2), Seq(create_row(1), create_row(2)))
  checkTuple(mkStack(2, 1, 2, 3), Seq(create_row(1, 2), create_row(3, null)))
  checkTuple(mkStack(3, 1, 2, 3), Seq(1, 2, 3).map(create_row(_)))
  checkTuple(mkStack(4, 1, 2, 3), Seq(1, 2, 3, null).map(create_row(_)))

  // Heterogeneous column types are allowed as long as each column is consistent.
  checkTuple(
    mkStack(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c"),
    Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, "c")))

  // Input validation: arity, row-count type, and per-column type consistency.
  assert(mkStack(1).checkInputDataTypes().isFailure)
  assert(mkStack(1.0).checkInputDataTypes().isFailure)
  assert(mkStack(1, 1, 1.0).checkInputDataTypes().isSuccess)
  assert(mkStack(2, 1, 1.0).checkInputDataTypes().isFailure)
}
}
......@@ -23,6 +23,59 @@ import org.apache.spark.sql.test.SharedSQLContext
class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("stack") {
  val df = spark.range(1)

  // Capture the analysis error message produced while resolving an expression.
  def errorMessageOf(body: => Any): String =
    intercept[AnalysisException](body).getMessage

  // An empty input DataFrame suppresses row generation entirely.
  checkAnswer(spark.emptyDataFrame.selectExpr("stack(1, 1, 2, 3)"), Nil)

  // Rows & columns, including NULL padding of the last row.
  checkAnswer(df.selectExpr("stack(1, 1, 2, 3)"), Row(1, 2, 3) :: Nil)
  checkAnswer(df.selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Nil)
  checkAnswer(df.selectExpr("stack(3, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Nil)
  checkAnswer(df.selectExpr("stack(4, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)

  // Various column types
  checkAnswer(
    df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
    Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)

  // Generation repeats for every input row.
  checkAnswer(
    spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
    Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)

  // The first argument must be a positive constant integer.
  assert(errorMessageOf(df.selectExpr("stack(1.1, 1, 2, 3)"))
    .contains("The number of rows must be a positive constant integer."))
  assert(errorMessageOf(df.selectExpr("stack(-1, 1, 2, 3)"))
    .contains("The number of rows must be a positive constant integer."))

  // The data for the same column should have the same type.
  assert(errorMessageOf(df.selectExpr("stack(2, 1, '2.2')"))
    .contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)"))

  // stack on column data (the row count must still be a constant, not a column).
  val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
  checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) :: Nil)
  assert(errorMessageOf(df2.selectExpr("stack(n, a, b, c)"))
    .contains("The number of rows must be a positive constant integer."))

  val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
  assert(errorMessageOf(df3.selectExpr("stack(2, a, b)"))
    .contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)"))
}
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
......
......@@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "stack", "str_to_map",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string"
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment