diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 0fcfb97d2bd905944810c1aef277bcf0097d2d00..2f685c5f9cb51b5c7ba8396b754a1e0b3bd5dfff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData.DecimalData import org.apache.spark.sql.types.DecimalType case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) @@ -69,6 +70,14 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0))) ) + val decimalDataWithNulls = sqlContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, null) :: + DecimalData(2, 1) :: + DecimalData(2, null) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: + DecimalData(null, 2) :: Nil).toDF() checkAnswer( decimalDataWithNulls.groupBy("a").agg(sum("b")), Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index c5f25fa1df3b190f282a8a11a4d959e5b9e63d6b..7fa6760b71c8bd26c9934020caf1a7459813512a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -103,19 +103,6 @@ private[sql] trait SQLTestData { self => df } - protected lazy val decimalDataWithNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( - DecimalDataWithNulls(1, 1) :: - DecimalDataWithNulls(1, null) :: - DecimalDataWithNulls(2, 1) :: - DecimalDataWithNulls(2, null) :: - DecimalDataWithNulls(3, 1) :: - DecimalDataWithNulls(3, 2) :: - DecimalDataWithNulls(null, 2) :: Nil).toDF() - df.registerTempTable("decimalDataWithNulls") - df - } - protected lazy val binaryData: DataFrame = { val df = sqlContext.sparkContext.parallelize( BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: @@ -280,7 +267,6 @@ private[sql] trait SQLTestData { self => negativeData largeAndSmallInts decimalData - decimalDataWithNulls binaryData upperCaseData lowerCaseData @@ -310,7 +296,6 @@ private[sql] object SQLTestData { case class TestData3(a: Int, b: Option[Int]) case class LargeAndSmallInts(a: Int, b: Int) case class DecimalData(a: BigDecimal, b: BigDecimal) - case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal) case class BinaryData(a: Array[Byte], b: Int) case class UpperCaseData(N: Int, L: String) case class LowerCaseData(n: Int, l: String)