Commit 785f9558 authored by Yin Huai, committed by Michael Armbrust

[SPARK-6887][SQL] ColumnBuilder misses FloatType

https://issues.apache.org/jira/browse/SPARK-6887

Author: Yin Huai <yhuai@databricks.com>

Closes #5499 from yhuai/inMemFloat and squashes the following commits:

84cba38 [Yin Huai] Add test.
4b75ba6 [Yin Huai] Add FloatType back.
parent e3e4e9a3
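
For context, a minimal sketch of the scenario this commit addresses: caching a table whose schema includes a FloatType column and reading it back through the in-memory columnar cache. This is an illustrative reproduction, not part of the commit, and it assumes a Spark 1.3-era SQLContext running locally.

    // Illustrative reproduction sketch for SPARK-6887 (not part of the commit).
    // Assumes a Spark 1.3-era SQLContext running in local mode.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext

    case class FloatRecord(f: Float)

    object FloatCacheRepro {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local").setAppName("SPARK-6887-repro"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // A table with a FloatType column; before this fix, ColumnBuilder had
        // no dedicated case for FloatType when building the columnar cache.
        sc.parallelize(1 to 10).map(i => FloatRecord(i + 0.25f)).toDF()
          .registerTempTable("floats")

        sqlContext.sql("CACHE TABLE floats")           // materializes in-memory columns
        sqlContext.sql("SELECT f FROM floats").show()  // reads the cached columns back
        sc.stop()
      }
    }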
@@ -153,6 +153,7 @@ private[sql] object ColumnBuilder {
     val builder: ColumnBuilder = dataType match {
       case IntegerType => new IntColumnBuilder
       case LongType => new LongColumnBuilder
+      case FloatType => new FloatColumnBuilder
       case DoubleType => new DoubleColumnBuilder
       case BooleanType => new BooleanColumnBuilder
       case ByteType => new ByteColumnBuilder
...
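
The one-line fix above matters because of how the match falls through: without a dedicated FloatType case, float columns hit the default branch, which in Spark 1.3 is a generic, serialization-based builder. A self-contained sketch of that failure mode, using stand-in types rather than Spark's real builder classes:

    // Stand-in sketch of the dispatch fallthrough (not Spark's actual classes).
    sealed trait DataType
    case object IntegerType extends DataType
    case object FloatType extends DataType
    case object DoubleType extends DataType

    def builderFor(dataType: DataType): String = dataType match {
      case IntegerType => "IntColumnBuilder"
      // Before this commit the FloatType case below was missing, so floats
      // fell through to the catch-all, serialization-based builder.
      case FloatType => "FloatColumnBuilder"
      case DoubleType => "DoubleColumnBuilder"
      case _ => "GenericColumnBuilder"
    }

    assert(builderFor(FloatType) == "FloatColumnBuilder")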
@@ -104,9 +104,12 @@ object QueryTest {
     // Converts data to types that we can do equality comparison using Scala collections.
     // For BigDecimal type, the Scala type has a better definition of equality test (similar to
     // Java's java.math.BigDecimal.compareTo).
+    // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+    // equality tests.
     val converted: Seq[Row] = answer.map { s =>
       Row.fromSeq(s.toSeq.map {
         case d: java.math.BigDecimal => BigDecimal(d)
+        case b: Array[Byte] => b.toSeq
         case o => o
       })
     }
...
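
The added `case b: Array[Byte] => b.toSeq` works because Scala arrays are Java arrays, whose `==` compares references, so two distinct arrays holding identical bytes are never equal; wrapping them in a `Seq` gives element-wise equality. A quick demonstration:

    // Array == is reference equality; Seq == is element-wise.
    val a = Array[Byte](1, 2, 3)
    val b = Array[Byte](1, 2, 3)

    assert(a != b)             // distinct array objects, so reference equality fails
    assert(a.toSeq == b.toSeq) // wrapped as Seq, contents are compared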
@@ -17,11 +17,13 @@
 package org.apache.spark.sql.columnar
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.types.{DecimalType, Decimal}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -132,4 +134,59 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT * FROM test_fixed_decimal"),
       (1 to 10).map(i => Row(Decimal(i, 15, 10).toJavaBigDecimal)))
   }
+
+  test("test different data types") {
+    // Create the schema.
+    val struct =
+      StructType(
+        StructField("f1", FloatType, true) ::
+        StructField("f2", ArrayType(BooleanType), true) :: Nil)
+    val dataTypes =
+      Seq(StringType, BinaryType, NullType, BooleanType,
+        ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
+        DateType, TimestampType,
+        ArrayType(IntegerType), MapType(StringType, LongType), struct)
+    val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+      StructField(s"col$index", dataType, true)
+    }
+    val allColumns = fields.map(_.name).mkString(",")
+    val schema = StructType(fields)
+
+    // Create an RDD for the schema.
+    val rdd =
+      sparkContext.parallelize((1 to 100), 10).map { i =>
+        Row(
+          s"str${i}: test cache.",
+          s"binary${i}: test cache.".getBytes("UTF-8"),
+          null,
+          i % 2 == 0,
+          i.toByte,
+          i.toShort,
+          i,
+          Long.MaxValue - i.toLong,
+          (i + 0.25).toFloat,
+          (i + 0.75),
+          BigDecimal(Long.MaxValue.toString + ".12345"),
+          new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
+          new Date(i),
+          new Timestamp(i),
+          (1 to i).toSeq,
+          (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
+          Row((i - 0.25).toFloat, (1 to i).toSeq))
+      }
+    createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
+    // Cache the table.
+    sql("cache table InMemoryCache_different_data_types")
+    // Make sure the table is indeed cached.
+    val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
+    assert(
+      isCached("InMemoryCache_different_data_types"),
+      "InMemoryCache_different_data_types should be cached.")
+    // Issue a query and check the results.
+    checkAnswer(
+      sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
+      table("InMemoryCache_different_data_types").collect())
+    dropTempTable("InMemoryCache_different_data_types")
+  }
 }
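
The test asserts cachedness via `isCached`; the `tableScan` val it computes hints at a stronger check one could make by inspecting the executed plan directly. A hedged sketch, assuming Spark 1.3's `InMemoryColumnarTableScan` physical operator from `org.apache.spark.sql.columnar`:

    // Sketch only: verify the executed plan actually scans the in-memory
    // columnar store, rather than just checking the cache-manager flag.
    import org.apache.spark.sql.columnar.InMemoryColumnarTableScan

    val plan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
    assert(
      plan.collect { case scan: InMemoryColumnarTableScan => scan }.nonEmpty,
      "expected an InMemoryColumnarTableScan in the executed plan")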