Skip to content
Snippets Groups Projects
Commit 958a0ec8 authored by Jia Li's avatar Jia Li Committed by Michael Armbrust
Browse files

[SPARK-11277][SQL] sort_array throws exception scala.MatchError

I'm new to spark. I was trying out the sort_array function then hit this exception. I looked into the spark source code. I found the root cause is that sort_array does not check for an array of NULLs. It's not meaningful to sort an array of entirely NULLs anyway.

I'm adding a check on the input array type to SortArray. If the array consists of NULLs entirely, there is no need to sort such array. I have also added a test case for this.

Please help to review my fix. Thanks!

Author: Jia Li <jiali@us.ibm.com>

Closes #9247 from jliwork/SPARK-11277.
parent 17f49992
No related branches found
No related tags found
No related merge requests found
...@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) ...@@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val lt: Comparator[Any] = { private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match { val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
} }
new Comparator[Any]() { new Comparator[Any]() {
...@@ -89,6 +90,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) ...@@ -89,6 +90,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
private lazy val gt: Comparator[Any] = { private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match { val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
} }
new Comparator[Any]() { new Comparator[Any]() {
...@@ -109,7 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) ...@@ -109,7 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def nullSafeEval(array: Any, ascending: Any): Any = { override def nullSafeEval(array: Any, ascending: Any): Any = {
val elementType = base.dataType.asInstanceOf[ArrayType].elementType val elementType = base.dataType.asInstanceOf[ArrayType].elementType
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) if (elementType != NullType) {
java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt)
}
new GenericArrayData(data.asInstanceOf[Array[Any]]) new GenericArrayData(data.asInstanceOf[Array[Any]])
} }
......
...@@ -49,6 +49,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -49,6 +49,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType))
val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType))
val a4 = Literal.create(Seq(null, null), ArrayType(NullType))
checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3))
checkEvaluation(new SortArray(a1), Seq[Integer]()) checkEvaluation(new SortArray(a1), Seq[Integer]())
...@@ -64,6 +65,12 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -64,6 +65,12 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null))
checkEvaluation(Literal.create(null, ArrayType(StringType)), null) checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
checkEvaluation(new SortArray(a4), Seq(null, null))
val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS)
checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2)))
} }
test("Array contains") { test("Array contains") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment