Commit b386943b authored by Kazuaki Ishizaki's avatar Kazuaki Ishizaki Committed by Andrew Or

[SPARK-17680][SQL][TEST] Added test cases for InMemoryRelation

## What changes were proposed in this pull request?

This pull request adds test cases for the following:
- caching all supported data types, with and without nulls
- accessing a `CachedBatch` with whole-stage codegen disabled (see the sketch after this list)
- accessing only some of the columns in a `CachedBatch`
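
The second case relies on shrinking `SQLConf.WHOLESTAGE_MAX_NUM_FIELDS` so that any row wider than the limit falls back to the interpreted, non-codegen read path. Below is a minimal standalone sketch of that trick, not part of this patch, assuming a local Spark 2.x `SparkSession` (the object name, app name, and column names are illustrative):

```scala
import org.apache.spark.sql.SparkSession

object CodegenOffCacheSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("codegen-off-cache-sketch")
      .getOrCreate()
    import spark.implicits._

    // "spark.sql.codegen.maxFields" is the key behind
    // SQLConf.WHOLESTAGE_MAX_NUM_FIELDS in Spark 2.x; with it set to 2,
    // any row wider than two fields skips whole-stage codegen.
    spark.conf.set("spark.sql.codegen.maxFields", "2")

    val df = Seq((1, 2L, 3.0f), (4, 5L, 6.0f)).toDF("i", "l", "f")
    df.persist()
    df.count()  // materializes the cached batches
    df.show()   // reads them back through the interpreted path
    spark.stop()
  }
}
```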

This PR is part of https://github.com/apache/spark/pull/15219, which motivates these tests: once that PR is enabled, the first two cases are handled by specialized (generated) code, and the third case is a known pitfall for it.

Independently of that PR, these tests increase coverage of the in-memory columnar path.

## How was this patch tested?

Added the test cases to `InMemoryColumnarQuerySuite`.

Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com>

Closes #15462 from kiszk/columnartestsuites.
parent 81e3f971
@@ -20,18 +20,96 @@ package org.apache.spark.sql.execution.columnar
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel._
class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
setupTestData()
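// Registers `data` as a temp view, caches it as an InMemoryRelation
// (batch size 5, compression on), and verifies the storage level, the
// cached batch type, and the query answer.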
private def cachePrimitiveTest(data: DataFrame, dataType: String): Unit = {
data.createOrReplaceTempView(s"testData$dataType")
val storageLevel = MEMORY_ONLY
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None)
assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
inMemoryRelation.cachedColumnBuffers.collect().head match {
case _: CachedBatch =>
case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}")
}
checkAnswer(inMemoryRelation, data.collect().toSeq)
}
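// Exercises every primitive-ish column type (boolean through timestamp,
// plus two decimal precisions); with nullability = true, every third row
// is null in all columns.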
private def testPrimitiveType(nullability: Boolean): Unit = {
val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, TimestampType, DecimalType(25, 5), DecimalType(6, 5))
val schema = StructType(dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, nullability)
})
val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
if (nullability && i % 3 == 0) null else if (i % 2 == 0) true else false,
if (nullability && i % 3 == 0) null else i.toByte,
if (nullability && i % 3 == 0) null else i.toShort,
if (nullability && i % 3 == 0) null else i.toInt,
if (nullability && i % 3 == 0) null else i.toLong,
if (nullability && i % 3 == 0) null else (i + 0.25).toFloat,
if (nullability && i % 3 == 0) null else (i + 0.75).toDouble,
if (nullability && i % 3 == 0) null else new Date(i),
if (nullability && i % 3 == 0) null else new Timestamp(i * 1000000L),
if (nullability && i % 3 == 0) null else BigDecimal(Long.MaxValue.toString + ".12345"),
if (nullability && i % 3 == 0) null
else new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456")
)))
cachePrimitiveTest(spark.createDataFrame(rdd, schema), "primitivesDateTimeStamp")
}
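// Exercises string, array, nested-array, map, and struct columns, again
// nulling every third row when nullability = true.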
private def testNonPrimitiveType(nullability: Boolean): Unit = {
val struct = StructType(StructField("f1", FloatType, false) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
val schema = StructType(Seq(
StructField("col0", StringType, nullability),
StructField("col1", ArrayType(IntegerType), nullability),
StructField("col2", ArrayType(ArrayType(IntegerType)), nullability),
StructField("col3", MapType(StringType, IntegerType), nullability),
StructField("col4", struct, nullability)
))
val rdd = spark.sparkContext.parallelize((1 to 10).map(i => Row(
if (nullability && i % 3 == 0) null else s"str${i}: test cache.",
if (nullability && i % 3 == 0) null else (i * 100 to i * 100 + i).toArray,
if (nullability && i % 3 == 0) null
else Array(Array(i, i + 1), Array(i * 100 + 1, i * 100, i * 100 + 2)),
if (nullability && i % 3 == 0) null else (i to i + i).map(j => s"key$j" -> j).toMap,
if (nullability && i % 3 == 0) null else Row((i + 0.25).toFloat, Seq(true, false, null))
)))
cachePrimitiveTest(spark.createDataFrame(rdd, schema), "StringArrayMapStruct")
}
test("primitive type with nullability:true") {
testPrimitiveType(true)
}
test("primitive type with nullability:false") {
testPrimitiveType(false)
}
test("non-primitive type with nullability:true") {
val schemaNull = StructType(Seq(StructField("col", NullType, true)))
val rddNull = spark.sparkContext.parallelize((1 to 10).map(i => Row(null)))
cachePrimitiveTest(spark.createDataFrame(rddNull, schemaNull), "Null")
testNonPrimitiveType(true)
}
test("non-primitive type with nullability:false") {
testNonPrimitiveType(false)
}
test("simple columnar query") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -58,6 +136,13 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
}.map(Row.fromTuple))
}
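// Reading only a subset of the cached columns is the pitfall case called
// out in the PR description.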
test("access only some column of the all of columns") {
val df = spark.range(1, 100).map(i => (i, (i + 1).toFloat)).toDF("i", "f")
df.cache()
df.count() // force the cache to be built
assert(df.filter("f <= 10.0").count() == 9)
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -246,4 +331,63 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
assert(cached.batchStats.value === expectedAnswer.size * INT.defaultSize)
}
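// The three tests below shrink WHOLESTAGE_MAX_NUM_FIELDS to 2 so that the
// wide rows bypass whole-stage codegen and the CachedBatches are read
// through the interpreted path instead.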
test("access primitive-type columns in CachedBatch without whole stage codegen") {
// whole stage codegen is not applied to a row with more than WHOLESTAGE_MAX_NUM_FIELDS fields
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
val data = Seq(null, true, 1.toByte, 3.toShort, 7, 15.toLong,
31.25.toFloat, 63.75, new Date(127), new Timestamp(255000000L), null)
val dataTypes = Seq(NullType, BooleanType, ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, TimestampType, IntegerType)
val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, true)
}
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist.take(1).apply(0)
checkAnswer(df, row)
}
}
test("access decimal/string-type columns in CachedBatch without whole stage codegen") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
val data = Seq(BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal("1234567890.12345"),
new java.math.BigDecimal("1.23456"),
"test123"
)
val schemas = Seq(
StructField("col0", DecimalType(25, 5), true),
StructField("col1", DecimalType(15, 5), true),
StructField("col2", DecimalType(6, 5), true),
StructField("col3", StringType, true)
)
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist.take(1).apply(0)
checkAnswer(df, row)
}
}
test("access non-primitive-type columns in CachedBatch without whole stage codegen") {
withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "2") {
val data = Seq((1 to 10).toArray,
Array(Array(10, 11), Array(100, 111, 123)),
Map("key1" -> 111, "key2" -> 222),
Row(1.25.toFloat, Seq(true, false, null))
)
val struct = StructType(StructField("f1", FloatType, false) ::
StructField("f2", ArrayType(BooleanType), true) :: Nil)
val schemas = Seq(
StructField("col0", ArrayType(IntegerType), true),
StructField("col1", ArrayType(ArrayType(IntegerType)), true),
StructField("col2", MapType(StringType, IntegerType), true),
StructField("col3", struct, true)
)
val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
val df = spark.createDataFrame(rdd, StructType(schemas))
val row = df.persist.take(1).apply(0)
checkAnswer(df, row)
}
}
}