Commit 6ebe419f authored by Reynold Xin

[SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd.

Fixed the following packages:
sql.columnar
sql.jdbc
sql.json
sql.parquet

Author: Reynold Xin <rxin@databricks.com>

Closes #6667 from rxin/testsqlcontext_wildcard and squashes the following commits:

134a776 [Reynold Xin] Fixed compilation break.
6da7b69 [Reynold Xin] [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ cont'd.
parent 356a4a9b
Showing with 234 additions and 245 deletions
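
The change described in the commit message follows one recurring pattern across these suites. A minimal sketch is shown below; the suite name ExampleSuite, its test body, and the temp table name are hypothetical and only illustrate the pattern, which is taken from the diff that follows. Before the change, suites pulled everything in through wildcard imports on the TestSQLContext singleton:

    import org.apache.spark.sql.{QueryTest, Row}
    import org.apache.spark.sql.test.TestSQLContext._            // wildcard on the singleton
    import org.apache.spark.sql.test.TestSQLContext.implicits._  // wildcard implicits

    // Hypothetical suite illustrating the old style.
    class ExampleSuite extends QueryTest {
      test("query a cached temp table") {
        // sparkContext, cacheTable and sql all resolve through the wildcard imports
        sparkContext.parallelize(1 to 3).map(i => (i, i.toString)).toDF("i", "s")
          .registerTempTable("example")
        cacheTable("example")
        checkAnswer(sql("SELECT i FROM example"), Seq(Row(1), Row(2), Row(3)))
      }
    }

After the change, each suite holds the context in an explicit lazy val and qualifies calls through it, importing only the implicits it needs:

    import org.apache.spark.sql.{QueryTest, Row}

    // Hypothetical suite illustrating the new style used throughout this commit.
    class ExampleSuite extends QueryTest {
      private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
      import ctx.implicits._

      test("query a cached temp table") {
        // every context method is now explicitly qualified with ctx
        ctx.sparkContext.parallelize(1 to 3).map(i => (i, i.toString)).toDF("i", "s")
          .registerTempTable("example")
        ctx.cacheTable("example")
        checkAnswer(ctx.sql("SELECT i FROM example"), Seq(Row(1), Row(2), Row(3)))
      }
    }

Making the context an explicit, lazily initialized member presumably keeps the singleton from being touched at class-load time and makes it obvious which SQLContext each call goes through, which also makes it easier to swap in a different context later.
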
......@@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
......@@ -31,8 +29,12 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.{logicalPlanToSparkQuery, sql}
test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
......@@ -40,16 +42,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
.toDF().registerTempTable("sizeTst")
cacheTable("sizeTst")
ctx.cacheTable("sizeTst")
assert(
table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
conf.autoBroadcastJoinThreshold)
ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
ctx.conf.autoBroadcastJoinThreshold)
}
test("projection") {
val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().map {
......@@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
}
test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
val plan = executePlan(testData.logicalPlan).executedPlan
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
checkAnswer(scan, testData.collect().toSeq)
......@@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM repeatedData"),
repeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("repeatedData")
ctx.cacheTable("repeatedData")
checkAnswer(
sql("SELECT * FROM repeatedData"),
......@@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM nullableRepeatedData"),
nullableRepeatedData.collect().toSeq.map(Row.fromTuple))
cacheTable("nullableRepeatedData")
ctx.cacheTable("nullableRepeatedData")
checkAnswer(
sql("SELECT * FROM nullableRepeatedData"),
......@@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT time FROM timestamps"),
timestamps.collect().toSeq.map(Row.fromTuple))
cacheTable("timestamps")
ctx.cacheTable("timestamps")
checkAnswer(
sql("SELECT time FROM timestamps"),
......@@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
sql("SELECT * FROM withEmptyParts"),
withEmptyParts.collect().toSeq.map(Row.fromTuple))
cacheTable("withEmptyParts")
ctx.cacheTable("withEmptyParts")
checkAnswer(
sql("SELECT * FROM withEmptyParts"),
......@@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Create a RDD for the schema
val rdd =
sparkContext.parallelize((1 to 100), 10).map { i =>
ctx.sparkContext.parallelize((1 to 100), 10).map { i =>
Row(
s"str${i}: test cache.",
s"binary${i}: test cache.".getBytes("UTF-8"),
......@@ -175,18 +177,18 @@ class InMemoryColumnarQuerySuite extends QueryTest {
(0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap,
Row((i - 0.25).toFloat, Seq(true, false, null)))
}
createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types")
// Cache the table.
sql("cache table InMemoryCache_different_data_types")
// Make sure the table is indeed cached.
val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan
val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan
assert(
isCached("InMemoryCache_different_data_types"),
ctx.isCached("InMemoryCache_different_data_types"),
"InMemoryCache_different_data_types should be cached.")
// Issue a query and check the results.
checkAnswer(
sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"),
table("InMemoryCache_different_data_types").collect())
dropTempTable("InMemoryCache_different_data_types")
ctx.table("InMemoryCache_different_data_types").collect())
ctx.dropTempTable("InMemoryCache_different_data_types")
}
}
......@@ -21,40 +21,42 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
}, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
// Enable in-memory table scan accumulators
setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
}
override protected def afterAll(): Unit = {
setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
before {
cacheTable("pruningData")
ctx.cacheTable("pruningData")
}
after {
uncacheTable("pruningData")
ctx.uncacheTable("pruningData")
}
// Comparisons
......@@ -108,7 +110,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
val df = sql(query)
val df = ctx.sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
......
......@@ -21,13 +21,11 @@ import java.math.BigDecimal
import java.sql.DriverManager
import java.util.{Calendar, GregorianCalendar, Properties}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import TestSQLContext._
import TestSQLContext.implicits._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb0"
......@@ -37,12 +35,16 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
val testH2Dialect = new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
Some(StringType)
}
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
......@@ -253,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("Basic API") {
assert(TestSQLContext.read.jdbc(
assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
}
test("Basic API with FetchSize") {
val properties = new Properties
properties.setProperty("fetchSize", "2")
assert(TestSQLContext.read.jdbc(
assert(ctx.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
.collect().length === 3)
}
......@@ -328,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test DATE types") {
val rows = TestSQLContext.read.jdbc(
val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
......@@ -338,9 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test DATE types in cache") {
val rows =
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
.cache().registerTempTable("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
......@@ -348,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}
test("test types for null value") {
val rows = TestSQLContext.read.jdbc(
val rows = ctx.read.jdbc(
urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
......@@ -395,10 +396,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(
_.dataType != org.apache.spark.sql.types.StringType
).isEmpty)
val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
assert(rows(0).get(1).isInstanceOf[String])
......@@ -419,7 +418,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
test("Aggregated dialects") {
val agg = new AggregatedDialect(List(new JdbcDialect {
def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
if (sqlType % 2 == 0) {
......@@ -430,8 +429,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
}, testH2Dialect))
assert(agg.canHandle("jdbc:h2:xxx"))
assert(!agg.canHandle("jdbc:h2"))
assert(agg.getCatalystType(0, "", 1, null) == Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) == Some(StringType))
assert(agg.getCatalystType(0, "", 1, null) === Some(LongType))
assert(agg.getCatalystType(1, "", 1, null) === Some(StringType))
}
}
......@@ -24,7 +24,6 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SaveMode, Row}
import org.apache.spark.sql.test._
import org.apache.spark.sql.types._
class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
......@@ -37,6 +36,10 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.implicits._
import ctx.sql
before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
......@@ -54,14 +57,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
TestSQLContext.sql(
ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
TestSQLContext.sql(
ctx.sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
......@@ -74,66 +77,64 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
conn1.close()
}
val sc = TestSQLContext.sparkContext
private lazy val sc = ctx.sparkContext
val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
val arr1x2 = Array[Row](Row.apply("fred", 3))
val schema2 = StructType(
private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
private lazy val arr1x2 = Array[Row](Row.apply("fred", 3))
private lazy val schema2 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) :: Nil)
val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
val schema3 = StructType(
private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
private lazy val schema3 = StructType(
StructField("name", StringType) ::
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
assert(2 ==
TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
}
test("CREATE with overwrite") {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
assert(2 ==
TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
intercept[org.apache.spark.SparkException] {
......@@ -142,15 +143,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
......@@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp}
import com.fasterxml.jackson.core.JsonFactory
import org.scalactic.Tolerance._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.json.InferSchema.compatibleType
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.util.Utils
class JsonSuite extends QueryTest {
import org.apache.spark.sql.json.TestJsonData._
class JsonSuite extends QueryTest with TestJsonData {
TestJsonData
protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
import ctx.sql
import ctx.implicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
......@@ -214,7 +212,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring with null in sampling") {
val jsonDF = read.json(jsonNullStruct)
val jsonDF = ctx.read.json(jsonNullStruct)
val expectedSchema = StructType(
StructField("headers", StructType(
StructField("Charset", StringType, true) ::
......@@ -233,7 +231,7 @@ class JsonSuite extends QueryTest {
}
test("Primitive field and type inferring") {
val jsonDF = read.json(primitiveFieldAndType)
val jsonDF = ctx.read.json(primitiveFieldAndType)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
......@@ -261,7 +259,7 @@ class JsonSuite extends QueryTest {
}
test("Complex field and type inferring") {
val jsonDF = read.json(complexFieldAndType1)
val jsonDF = ctx.read.json(complexFieldAndType1)
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) ::
......@@ -360,7 +358,7 @@ class JsonSuite extends QueryTest {
}
test("GetField operation on complex data type") {
val jsonDF = read.json(complexFieldAndType1)
val jsonDF = ctx.read.json(complexFieldAndType1)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
......@@ -376,7 +374,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in primitive field values") {
val jsonDF = read.json(primitiveFieldValueTypeConflict)
val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
......@@ -450,7 +448,7 @@ class JsonSuite extends QueryTest {
}
ignore("Type conflict in primitive field values (Ignored)") {
val jsonDF = read.json(primitiveFieldValueTypeConflict)
val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict)
jsonDF.registerTempTable("jsonTable")
// Right now, the analyzer does not promote strings in a boolean expression.
......@@ -503,7 +501,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in complex field values") {
val jsonDF = read.json(complexFieldValueTypeConflict)
val jsonDF = ctx.read.json(complexFieldValueTypeConflict)
val expectedSchema = StructType(
StructField("array", ArrayType(LongType, true), true) ::
......@@ -527,7 +525,7 @@ class JsonSuite extends QueryTest {
}
test("Type conflict in array elements") {
val jsonDF = read.json(arrayElementTypeConflict)
val jsonDF = ctx.read.json(arrayElementTypeConflict)
val expectedSchema = StructType(
StructField("array1", ArrayType(StringType, true), true) ::
......@@ -555,7 +553,7 @@ class JsonSuite extends QueryTest {
}
test("Handling missing fields") {
val jsonDF = read.json(missingFields)
val jsonDF = ctx.read.json(missingFields)
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
......@@ -574,8 +572,9 @@ class JsonSuite extends QueryTest {
val dir = Utils.createTempDir()
dir.delete()
val path = dir.getCanonicalPath
sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val jsonDF = read.option("samplingRatio", "0.49").json(path)
ctx.sparkContext.parallelize(1 to 100)
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path)
val analyzed = jsonDF.queryExecution.analyzed
assert(
......@@ -590,7 +589,7 @@ class JsonSuite extends QueryTest {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.path === Some(path))
assert(relationWithSchema.schema === schema)
......@@ -602,7 +601,7 @@ class JsonSuite extends QueryTest {
dir.delete()
val path = dir.getCanonicalPath
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val jsonDF = read.json(path)
val jsonDF = ctx.read.json(path)
val expectedSchema = StructType(
StructField("bigInteger", DecimalType.Unlimited, true) ::
......@@ -671,7 +670,7 @@ class JsonSuite extends QueryTest {
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
val jsonDF1 = read.schema(schema).json(path)
val jsonDF1 = ctx.read.schema(schema).json(path)
assert(schema === jsonDF1.schema)
......@@ -688,7 +687,7 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
val jsonDF2 = read.schema(schema).json(primitiveFieldAndType)
val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType)
assert(schema === jsonDF2.schema)
......@@ -709,7 +708,7 @@ class JsonSuite extends QueryTest {
test("Applying schemas with MapType") {
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1)
val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
......@@ -737,7 +736,7 @@ class JsonSuite extends QueryTest {
val schemaWithComplexMap = StructType(
StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2)
val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2)
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
......@@ -763,7 +762,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-2096 Correctly parse dot notations") {
val jsonDF = read.json(complexFieldAndType2)
val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
......@@ -781,7 +780,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3390 Complex arrays") {
val jsonDF = read.json(complexFieldAndType2)
val jsonDF = ctx.read.json(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
......@@ -804,7 +803,7 @@ class JsonSuite extends QueryTest {
}
test("SPARK-3308 Read top level JSON arrays") {
val jsonDF = read.json(jsonArray)
val jsonDF = ctx.read.json(jsonArray)
jsonDF.registerTempTable("jsonTable")
checkAnswer(
......@@ -822,10 +821,10 @@ class JsonSuite extends QueryTest {
test("Corrupt records") {
// Test if we can query corrupt records.
val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val jsonDF = read.json(corruptRecords)
val jsonDF = ctx.read.json(corruptRecords)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
......@@ -875,11 +874,11 @@ class JsonSuite extends QueryTest {
Row("]") :: Nil
)
TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
test("SPARK-4068: nulls in arrays") {
val jsonDF = read.json(nullsInArrays)
val jsonDF = ctx.read.json(nullsInArrays)
jsonDF.registerTempTable("jsonTable")
val schema = StructType(
......@@ -925,7 +924,7 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
val df1 = createDataFrame(rowRDD1, schema1)
val df1 = ctx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
val df2 = df1.toDF
val result = df2.toJSON.collect()
......@@ -948,7 +947,7 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
val df3 = createDataFrame(rowRDD2, schema2)
val df3 = ctx.createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
val df4 = df3.toDF
val result2 = df4.toJSON.collect()
......@@ -956,8 +955,8 @@ class JsonSuite extends QueryTest {
assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
val jsonDF = read.json(primitiveFieldAndType)
val primTable = read.json(jsonDF.toJSON)
val jsonDF = ctx.read.json(primitiveFieldAndType)
val primTable = ctx.read.json(jsonDF.toJSON)
primTable.registerTempTable("primativeTable")
checkAnswer(
sql("select * from primativeTable"),
......@@ -969,8 +968,8 @@ class JsonSuite extends QueryTest {
"this is a simple string.")
)
val complexJsonDF = read.json(complexFieldAndType1)
val compTable = read.json(complexJsonDF.toJSON)
val complexJsonDF = ctx.read.json(complexFieldAndType1)
val compTable = ctx.read.json(complexJsonDF.toJSON)
compTable.registerTempTable("complexTable")
// Access elements of a primitive array.
checkAnswer(
......@@ -1074,29 +1073,29 @@ class JsonSuite extends QueryTest {
}
test("SPARK-7565 MapType in JsonRDD") {
val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord
TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true")
val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord
ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed")
val schemaWithSimpleMap = StructType(
StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
try{
for (useStreaming <- List("true", "false")) {
setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
val temp = Utils.createTempDir().getPath
val df = read.schema(schemaWithSimpleMap).json(mapType1)
val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1)
df.write.mode("overwrite").parquet(temp)
// order of MapType is not defined
assert(read.parquet(temp).count() == 5)
assert(ctx.read.parquet(temp).count() == 5)
val df2 = read.json(corruptRecords)
val df2 = ctx.read.json(corruptRecords)
df2.write.mode("overwrite").parquet(temp)
checkAnswer(read.parquet(temp), df2.collect())
checkAnswer(ctx.read.parquet(temp), df2.collect())
}
} finally {
setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming)
ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord)
}
}
......
......@@ -17,12 +17,15 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
object TestJsonData {
trait TestJsonData {
val primitiveFieldAndType =
TestSQLContext.sparkContext.parallelize(
protected def ctx: SQLContext
def primitiveFieldAndType: RDD[String] =
ctx.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
......@@ -32,8 +35,8 @@ object TestJsonData {
"null":null
}""" :: Nil)
val primitiveFieldValueTypeConflict =
TestSQLContext.sparkContext.parallelize(
def primitiveFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
......@@ -43,15 +46,15 @@ object TestJsonData {
"""{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470,
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
val jsonNullStruct =
TestSQLContext.sparkContext.parallelize(
def jsonNullStruct: RDD[String] =
ctx.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
val complexFieldValueTypeConflict =
TestSQLContext.sparkContext.parallelize(
def complexFieldValueTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
......@@ -61,23 +64,23 @@ object TestJsonData {
"""{"num_struct":{}, "str_array":["str1", "str2", 33],
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
val arrayElementTypeConflict =
TestSQLContext.sparkContext.parallelize(
def arrayElementTypeConflict: RDD[String] =
ctx.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
val missingFields =
TestSQLContext.sparkContext.parallelize(
def missingFields: RDD[String] =
ctx.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
"""{"d":{"field":true}}""" ::
"""{"e":"str"}""" :: Nil)
val complexFieldAndType1 =
TestSQLContext.sparkContext.parallelize(
def complexFieldAndType1: RDD[String] =
ctx.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
......@@ -92,8 +95,8 @@ object TestJsonData {
"arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]]
}""" :: Nil)
val complexFieldAndType2 =
TestSQLContext.sparkContext.parallelize(
def complexFieldAndType2: RDD[String] =
ctx.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
......@@ -146,16 +149,16 @@ object TestJsonData {
]]
}""" :: Nil)
val mapType1 =
TestSQLContext.sparkContext.parallelize(
def mapType1: RDD[String] =
ctx.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
"""{"map": {"c": 1, "d": 4}}""" ::
"""{"map": {"e": null}}""" :: Nil)
val mapType2 =
TestSQLContext.sparkContext.parallelize(
def mapType2: RDD[String] =
ctx.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
......@@ -163,22 +166,22 @@ object TestJsonData {
"""{"map": {"e": null}}""" ::
"""{"map": {"f": {"field1": null}}}""" :: Nil)
val nullsInArrays =
TestSQLContext.sparkContext.parallelize(
def nullsInArrays: RDD[String] =
ctx.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
val jsonArray =
TestSQLContext.sparkContext.parallelize(
def jsonArray: RDD[String] =
ctx.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
val corruptRecords =
TestSQLContext.sparkContext.parallelize(
def corruptRecords: RDD[String] =
ctx.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
......@@ -186,6 +189,5 @@ object TestJsonData {
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""]""" :: Nil)
val empty =
TestSQLContext.sparkContext.parallelize(Seq[String]())
def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
}
......@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
......@@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
* data type is nullable.
*/
class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
private def checkFilterPredicate(
df: DataFrame,
......@@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
......@@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA
}
class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
......
......@@ -36,9 +36,6 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.util.DateUtils
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode}
......@@ -66,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
* A test suite that tests basic Parquet I/O.
*/
class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
import sqlContext.implicits.localSeqToDataFrameHolder
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
......@@ -104,7 +100,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
test("fixed-length decimals") {
def makeDecimalRDD(decimal: DecimalType): DataFrame =
sparkContext
sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(i / 100.0))
.toDF()
......@@ -115,7 +111,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
......@@ -123,7 +119,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
read.parquet(dir.getCanonicalPath).collect()
sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
......@@ -131,14 +127,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
intercept[Throwable] {
withTempPath { dir =>
makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
read.parquet(dir.getCanonicalPath).collect()
sqlContext.read.parquet(dir.getCanonicalPath).collect()
}
}
}
test("date type") {
def makeDateRDD(): DataFrame =
sparkContext
sqlContext.sparkContext
.parallelize(0 to 1000)
.map(i => Tuple1(DateUtils.toJavaDate(i)))
.toDF()
......@@ -147,7 +143,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempPath { dir =>
val data = makeDateRDD()
data.write.parquet(dir.getCanonicalPath)
checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
......@@ -236,7 +232,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
def checkCompressionCodec(codec: CompressionCodecName): Unit = {
withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
withParquetFile(data) { path =>
assertResult(conf.parquetCompressionCodec.toUpperCase) {
assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
compressionCodecFor(path)
}
}
......@@ -244,7 +240,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
// Checks default compression codec
checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec))
checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
checkCompressionCodec(CompressionCodecName.GZIP)
......@@ -283,7 +279,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "part-r-0.parquet")
makeRawParquetFile(path)
checkAnswer(read.parquet(path.toString), (0 until 10).map { i =>
checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
})
}
......@@ -312,7 +308,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
checkAnswer(read.parquet(file), newData.map(Row.fromTuple))
checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
}
}
......@@ -321,7 +317,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
checkAnswer(read.parquet(file), data.map(Row.fromTuple))
checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
}
}
......@@ -341,7 +337,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
withParquetFile(data) { file =>
val newData = (11 to 20).map(i => (i, i.toString))
newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple))
checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
}
}
......@@ -369,11 +365,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val path = new Path(location.getCanonicalPath)
ParquetFileWriter.writeMetadataFile(
sparkContext.hadoopConfiguration,
sqlContext.sparkContext.hadoopConfiguration,
path,
new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil)
assertResult(read.parquet(path.toString).schema) {
assertResult(sqlContext.read.parquet(path.toString).schema) {
StructType(
StructField("a", BooleanType, nullable = false) ::
StructField("b", IntegerType, nullable = false) ::
......@@ -406,7 +402,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
......@@ -430,7 +426,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA
}
class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
......
......@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.parquet
import java.io.File
......@@ -28,7 +29,6 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.sources.PartitioningUtils._
import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext}
......@@ -39,10 +39,10 @@ case class ParquetData(intField: Int, stringField: String)
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
override val sqlContext: SQLContext = TestSQLContext
import sqlContext._
override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
......@@ -190,8 +190,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
// Introduce _temporary dir to the base dir to test the robustness of the schema discovery process.
new File(base.getCanonicalPath, "_temporary").mkdir()
println("load the partitioned table")
read.parquet(base.getCanonicalPath).registerTempTable("t")
sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
......@@ -238,7 +237,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
read.parquet(base.getCanonicalPath).registerTempTable("t")
sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
......@@ -286,7 +285,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
......@@ -326,7 +325,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
}
val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
parquetRelation.registerTempTable("t")
withTempTable("t") {
......@@ -358,7 +357,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
(1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
makePartitionDir(base, defaultPartitionName, "pi" -> 2))
read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t")
sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t")
withTempTable("t") {
checkAnswer(
......@@ -371,7 +370,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
test("SPARK-7749 Non-partitioned table should have empty partition spec") {
withTempPath { dir =>
(1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution
val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution
queryExecution.analyzed.collectFirst {
case LogicalRelation(relation: ParquetRelation2) =>
assert(relation.partitionSpec === PartitionSpec.emptySpec)
......@@ -385,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
withTempPath { dir =>
val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s")
df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath)
checkAnswer(read.parquet(dir.getCanonicalPath), df.collect())
checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect())
}
}
......@@ -425,12 +424,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
}
val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema)
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema)
withTempPath { dir =>
df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
val fields = schema.map(f => Column(f.name).cast(f.dataType))
checkAnswer(read.load(dir.toString).select(fields: _*), row)
checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row)
}
}
......@@ -446,7 +445,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))
checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df)
checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df)
}
}
}
......@@ -22,14 +22,14 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
/**
* A test suite that tests various Parquet queries.
*/
class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
import sqlContext.implicits._
import sqlContext.sql
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
......@@ -40,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
}
catalog.unregisterTable(Seq("tmp"))
sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
checkAnswer(table("t"), data.map(Row.fromTuple))
checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
}
catalog.unregisterTable(Seq("tmp"))
sqlContext.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
......@@ -118,7 +118,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
val df2 = sqlContext.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
......@@ -127,7 +127,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
}
class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
......@@ -139,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd
}
class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
val originalConf = sqlContext.conf.parquetUseDataSourceApi
private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi
override protected def beforeAll(): Unit = {
sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
......
......@@ -24,11 +24,10 @@ import org.apache.parquet.schema.MessageTypeParser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
val sqlContext = TestSQLContext
lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
......
......@@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
private[sql] trait ParquetTest extends SQLTestUtils {
import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
import sqlContext.sparkContext
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
......@@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath)
sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
......@@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
......
......@@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
trait SQLTestUtils {
val sqlContext: SQLContext
def sqlContext: SQLContext
import sqlContext.{conf, sparkContext}
protected def configuration = sparkContext.hadoopConfiguration
protected def configuration = sqlContext.sparkContext.hadoopConfiguration
/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
......@@ -39,12 +37,12 @@ trait SQLTestUtils {
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
(keys, values).zipped.foreach(conf.setConf)
val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption)
(keys, values).zipped.foreach(sqlContext.conf.setConf)
try f finally {
keys.zip(currentValues).foreach {
case (key, Some(value)) => conf.setConf(key, value)
case (key, None) => conf.unsetConf(key)
case (key, Some(value)) => sqlContext.conf.setConf(key, value)
case (key, None) => sqlContext.conf.unsetConf(key)
}
}
}
......
......@@ -52,9 +52,6 @@ case class Contact(name: String, phone: String)
case class Person(name: String, age: Int, contacts: Seq[Contact])
class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
override val sqlContext = TestHive
import TestHive.read
def getTempFilePath(prefix: String, suffix: String = ""): File = {
val tempFile = File.createTempFile(prefix, suffix)
......@@ -69,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
withOrcFile(data) { file =>
checkAnswer(
read.format("orc").load(file),
sqlContext.read.format("orc").load(file),
data.toDF().collect())
}
}
......
......@@ -22,13 +22,11 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql._
private[sql] trait OrcTest extends SQLTestUtils {
protected def hiveContext = sqlContext.asInstanceOf[HiveContext]
lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
import sqlContext.sparkContext
import sqlContext.implicits._
......@@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path)))
}
/**
......@@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withOrcDataFrame(data) { df =>
hiveContext.registerDataFrameAsTable(df, tableName)
sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
......