Commit d2e44d7d authored by Takeshi YAMAMURO, committed by Herman van Hovell

[SPARK-16192][SQL] Add type checks in CollectSet

## What changes were proposed in this pull request?
`CollectSet` cannot handle map-typed data because the underlying `MapData` does not implement `equals`, so duplicate elimination over maps is undefined.
This PR therefore adds a type check, enforced during analysis (`CheckAnalysis`), that rejects map-typed input to `collect_set`.
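For context, a minimal sketch of the kind of query this check now rejects (hypothetical example; assumes a Spark 2.x `SparkSession` named `spark` with `spark.implicits._` in scope):

```scala
import org.apache.spark.sql.functions.collect_set
import spark.implicits._ // assumes a SparkSession named `spark` is in scope

// Two rows carry maps with identical contents.
val df = Seq((1, Map("k" -> 1)), (2, Map("k" -> 1))).toDF("id", "m")

// Without this check, collect_set would rely on equals/hashCode of the internal
// MapData to remove duplicates, which MapData does not define. With this patch
// the query fails analysis instead:
//   org.apache.spark.sql.AnalysisException: collect_set() cannot have map type data
df.select(collect_set($"m")).show()
```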

## How was this patch tested?
Added tests that verify analysis fails when map-typed data is passed to `CollectSet`.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #13892 from maropu/SPARK-16192.
parent 9053054c
@@ -73,9 +73,9 @@ trait CheckAnalysis extends PredicateHelper {
             s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

           case g: Grouping =>
-            failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup")
+            failAnalysis("grouping() can only be used with GroupingSets/Cube/Rollup")
           case g: GroupingID =>
-            failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
+            failAnalysis("grouping_id() can only be used with GroupingSets/Cube/Rollup")

           case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
             failAnalysis(s"Distinct window functions are not supported: $w")
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import scala.collection.generic.Growable
 import scala.collection.mutable

+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.catalyst.InternalRow
@@ -107,6 +108,14 @@ case class CollectSet(
   def this(child: Expression) = this(child, 0, 0)

+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure("collect_set() cannot have map type data")
+    }
+  }
+
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
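As a rough illustration of how the new `checkInputDataTypes` behaves, here is a sketch only (assumes the catalyst classes above are on the classpath; the literal children are hypothetical):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

// A child whose type contains a MapType anywhere (even nested in an array or
// struct) is rejected, because existsRecursively walks the whole type tree.
val mapChild = Literal.create(Map("k" -> 1), MapType(StringType, IntegerType))
val intChild = Literal.create(1, IntegerType)

new CollectSet(mapChild).checkInputDataTypes()
// => TypeCheckFailure("collect_set() cannot have map type data")
new CollectSet(intChild).checkInputDataTypes()
// => TypeCheckSuccess
```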
@@ -457,6 +457,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     )
   }

+  test("collect_set functions cannot have maps") {
+    val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1))
+      .toDF("a", "x", "y")
+      .select($"a", map($"x", $"y").as("b"))
+    val error = intercept[AnalysisException] {
+      df.select(collect_set($"a"), collect_set($"b"))
+    }
+    assert(error.message.contains("collect_set() cannot have map type data"))
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
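Because the check lives in the expression itself rather than in a single API path, SQL queries hit the same analysis failure. A small hedged sketch (hypothetical temp view name `t`, same assumed `spark` session as above):

```scala
// Register a temp view with a map-typed column (assumes spark.implicits._ in scope).
Seq((1, Map("k" -> 1))).toDF("id", "m").createOrReplaceTempView("t")

// Analysis fails with the same message the DataFrame test above checks:
//   collect_set() cannot have map type data
spark.sql("SELECT collect_set(m) FROM t")
```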