diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 055453e688e7379bd8dc24564303afa8f247ce62..fa3b8144c086e5c6b00f5b6a9ecb85f5f47be39d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -21,8 +21,6 @@ import java.sql.{Date, Timestamp}

 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{QueryTest, TestData}
 import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -31,8 +29,12 @@ class InMemoryColumnarQuerySuite extends QueryTest {
   // Make sure the tables are loaded.
   TestData

+  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+  import ctx.implicits._
+  import ctx.{logicalPlanToSparkQuery, sql}
+
   test("simple columnar query") {
-    val plan = executePlan(testData.logicalPlan).executedPlan
+    val plan = ctx.executePlan(testData.logicalPlan).executedPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().toSeq)
@@ -40,16 +42,16 @@ class InMemoryColumnarQuerySuite extends QueryTest {

   test("default size avoids broadcast") {
     // TODO: Improve this test when we have better statistics
-    sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+    ctx.sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
       .toDF().registerTempTable("sizeTst")
-    cacheTable("sizeTst")
+    ctx.cacheTable("sizeTst")
     assert(
-      table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
-        conf.autoBroadcastJoinThreshold)
+      ctx.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes >
+        ctx.conf.autoBroadcastJoinThreshold)
   }

   test("projection") {
-    val plan = executePlan(testData.select('value, 'key).logicalPlan).executedPlan
+    val plan = ctx.executePlan(testData.select('value, 'key).logicalPlan).executedPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().map {
@@ -58,7 +60,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
   }

   test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") {
-    val plan = executePlan(testData.logicalPlan).executedPlan
+    val plan = ctx.executePlan(testData.logicalPlan).executedPlan
     val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)

     checkAnswer(scan, testData.collect().toSeq)
@@ -70,7 +72,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT * FROM repeatedData"),
       repeatedData.collect().toSeq.map(Row.fromTuple))

-    cacheTable("repeatedData")
+    ctx.cacheTable("repeatedData")

     checkAnswer(
       sql("SELECT * FROM repeatedData"),
@@ -82,7 +84,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT * FROM nullableRepeatedData"),
       nullableRepeatedData.collect().toSeq.map(Row.fromTuple))

-    cacheTable("nullableRepeatedData")
+    ctx.cacheTable("nullableRepeatedData")

     checkAnswer(
       sql("SELECT * FROM nullableRepeatedData"),
@@ -94,7 +96,7 @@ class InMemoryColumnarQuerySuite extends QueryTest {
       sql("SELECT time FROM timestamps"),
       timestamps.collect().toSeq.map(Row.fromTuple))

-    cacheTable("timestamps")
+    ctx.cacheTable("timestamps")

     checkAnswer(
sql("SELECT time FROM timestamps"), @@ -106,7 +108,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM withEmptyParts"), withEmptyParts.collect().toSeq.map(Row.fromTuple)) - cacheTable("withEmptyParts") + ctx.cacheTable("withEmptyParts") checkAnswer( sql("SELECT * FROM withEmptyParts"), @@ -155,7 +157,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { // Create a RDD for the schema val rdd = - sparkContext.parallelize((1 to 100), 10).map { i => + ctx.sparkContext.parallelize((1 to 100), 10).map { i => Row( s"str${i}: test cache.", s"binary${i}: test cache.".getBytes("UTF-8"), @@ -175,18 +177,18 @@ class InMemoryColumnarQuerySuite extends QueryTest { (0 to i).map(j => s"map_key_$j" -> (Long.MaxValue - j)).toMap, Row((i - 0.25).toFloat, Seq(true, false, null))) } - createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") + ctx.createDataFrame(rdd, schema).registerTempTable("InMemoryCache_different_data_types") // Cache the table. sql("cache table InMemoryCache_different_data_types") // Make sure the table is indeed cached. - val tableScan = table("InMemoryCache_different_data_types").queryExecution.executedPlan + val tableScan = ctx.table("InMemoryCache_different_data_types").queryExecution.executedPlan assert( - isCached("InMemoryCache_different_data_types"), + ctx.isCached("InMemoryCache_different_data_types"), "InMemoryCache_different_data_types should be cached.") // Issue a query and check the results. checkAnswer( sql(s"SELECT DISTINCT ${allColumns} FROM InMemoryCache_different_data_types"), - table("InMemoryCache_different_data_types").collect()) - dropTempTable("InMemoryCache_different_data_types") + ctx.table("InMemoryCache_different_data_types").collect()) + ctx.dropTempTable("InMemoryCache_different_data_types") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index cda1b0992e36f2d96f83549535d7595303f94992..6545c6b314a4c1c600a903b4e9b07980c098e55b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -21,40 +21,42 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { - val originalColumnBatchSize = conf.columnBatchSize - val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning + + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize + private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning override protected def beforeAll(): Unit = { // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch - setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, "10") - val pruningData = sparkContext.makeRDD((1 to 100).map { key => + val pruningData = ctx.sparkContext.makeRDD((1 to 100).map { key => val string = if (((key - 1) / 10) % 2 == 0) null else key.toString TestData(key, string) }, 5).toDF() pruningData.registerTempTable("pruningData") // Enable in-memory partition 
pruning - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") // Enable in-memory table scan accumulators - setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") + ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true") } override protected def afterAll(): Unit = { - setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) - setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } before { - cacheTable("pruningData") + ctx.cacheTable("pruningData") } after { - uncacheTable("pruningData") + ctx.uncacheTable("pruningData") } // Comparisons @@ -108,7 +110,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val df = sql(query) + val df = ctx.sql(query) val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index e20c66cb2f1d7e2d0501d70c6f00c15d2fc96b17..7931854db27c17897623b40c69c00ea141e98bcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -21,13 +21,11 @@ import java.math.BigDecimal import java.sql.DriverManager import java.util.{Calendar, GregorianCalendar, Properties} -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ -import org.apache.spark.sql.types._ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter -import TestSQLContext._ -import TestSQLContext.implicits._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb0" @@ -37,12 +35,16 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) val testH2Dialect = new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = Some(StringType) } + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") // Extra properties that will be specified for our database. 
We need these to test @@ -253,26 +255,26 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("Basic API") { - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(TestSQLContext.read.jdbc( + assert(ctx.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } @@ -328,9 +330,9 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -338,9 +340,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test DATE types in cache") { - val rows = - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + ctx.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -348,7 +349,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { } test("test types for null value") { - val rows = TestSQLContext.read.jdbc( + val rows = ctx.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -395,10 +396,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) - assert(df.schema.filter( - _.dataType != org.apache.spark.sql.types.StringType - ).isEmpty) + val df = ctx.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) assert(rows(0).get(1).isInstanceOf[String]) @@ -419,7 +418,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("Aggregated dialects") { val agg = new AggregatedDialect(List(new JdbcDialect { - def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") + override def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:") override def 
getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = if (sqlType % 2 == 0) { @@ -430,8 +429,8 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) - assert(agg.getCatalystType(0, "", 1, null) == Some(LongType)) - assert(agg.getCatalystType(1, "", 1, null) == Some(StringType)) + assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) + assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 2de8c1a6098e0f356c442684cddaed69c4f1a8ce..d949ef42267ec3fd68ce1872034f49b1ed76752b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,7 +24,6 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{SaveMode, Row} -import org.apache.spark.sql.test._ import org.apache.spark.sql.types._ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { @@ -37,6 +36,10 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { properties.setProperty("password", "testPass") properties.setProperty("rowId", "false") + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) @@ -54,14 +57,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { "create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate() conn1.commit() - TestSQLContext.sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE |USING org.apache.spark.sql.jdbc |OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - TestSQLContext.sql( + ctx.sql( s""" |CREATE TEMPORARY TABLE PEOPLE1 |USING org.apache.spark.sql.jdbc @@ -74,66 +77,64 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { conn1.close() } - val sc = TestSQLContext.sparkContext + private lazy val sc = ctx.sparkContext - val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) - val arr1x2 = Array[Row](Row.apply("fred", 3)) - val schema2 = StructType( + private lazy val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222)) + private lazy val arr1x2 = Array[Row](Row.apply("fred", 3)) + private lazy val schema2 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: Nil) - val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) - val schema3 = StructType( + private lazy val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2)) + private lazy val schema3 = StructType( StructField("name", StringType) :: StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 == TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new 
Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 == TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 == - TestSQLContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === ctx.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) - val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) + val df = ctx.createDataFrame(sc.parallelize(arr2x2), schema2) + val df2 = ctx.createDataFrame(sc.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -142,15 +143,15 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter { } test("INSERT to JDBC Datasource") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == 
TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { - TestSQLContext.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - TestSQLContext.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 == TestSQLContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") + ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index f8d62f9e7e02bc1a41ed8186129dc9aa1d8f692f..d889c7be17ce731b2294b2de3b62eda0cd08169c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -23,21 +23,19 @@ import java.sql.{Date, Timestamp} import com.fasterxml.jackson.core.JsonFactory import org.scalactic.Tolerance._ +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.json.InferSchema.compatibleType import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.util.Utils -class JsonSuite extends QueryTest { - import org.apache.spark.sql.json.TestJsonData._ +class JsonSuite extends QueryTest with TestJsonData { - TestJsonData + protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.sql + import ctx.implicits._ test("Type promotion") { def checkTypePromotion(expected: Any, actual: Any) { @@ -214,7 +212,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonDF = read.json(jsonNullStruct) + val jsonDF = ctx.read.json(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -233,7 +231,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonDF = read.json(primitiveFieldAndType) + val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -261,7 +259,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: @@ -360,7 +358,7 @@ class JsonSuite extends QueryTest { } test("GetField operation on complex data type") { - val jsonDF = read.json(complexFieldAndType1) + val jsonDF = ctx.read.json(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ 
-376,7 +374,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -450,7 +448,7 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonDF = read.json(primitiveFieldValueTypeConflict) + val jsonDF = ctx.read.json(primitiveFieldValueTypeConflict) jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expression. @@ -503,7 +501,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonDF = read.json(complexFieldValueTypeConflict) + val jsonDF = ctx.read.json(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(LongType, true), true) :: @@ -527,7 +525,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonDF = read.json(arrayElementTypeConflict) + val jsonDF = ctx.read.json(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -555,7 +553,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonDF = read.json(missingFields) + val jsonDF = ctx.read.json(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -574,8 +572,9 @@ class JsonSuite extends QueryTest { val dir = Utils.createTempDir() dir.delete() val path = dir.getCanonicalPath - sparkContext.parallelize(1 to 100).map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) - val jsonDF = read.option("samplingRatio", "0.49").json(path) + ctx.sparkContext.parallelize(1 to 100) + .map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path) + val jsonDF = ctx.read.option("samplingRatio", "0.49").json(path) val analyzed = jsonDF.queryExecution.analyzed assert( @@ -590,7 +589,7 @@ class JsonSuite extends QueryTest { val schema = StructType(StructField("a", LongType, true) :: Nil) val logicalRelation = - read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] + ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation] val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation] assert(relationWithSchema.path === Some(path)) assert(relationWithSchema.schema === schema) @@ -602,7 +601,7 @@ class JsonSuite extends QueryTest { dir.delete() val path = dir.getCanonicalPath primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonDF = read.json(path) + val jsonDF = ctx.read.json(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -671,7 +670,7 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonDF1 = read.schema(schema).json(path) + val jsonDF1 = ctx.read.schema(schema).json(path) assert(schema === jsonDF1.schema) @@ -688,7 +687,7 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonDF2 = read.schema(schema).json(primitiveFieldAndType) + val jsonDF2 = ctx.read.schema(schema).json(primitiveFieldAndType) assert(schema === jsonDF2.schema) @@ -709,7 +708,7 @@ class JsonSuite extends QueryTest { test("Applying schemas with MapType") { val schemaWithSimpleMap = StructType( StructField("map", 
MapType(StringType, IntegerType, true), false) :: Nil) - val jsonWithSimpleMap = read.schema(schemaWithSimpleMap).json(mapType1) + val jsonWithSimpleMap = ctx.read.schema(schemaWithSimpleMap).json(mapType1) jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") @@ -737,7 +736,7 @@ class JsonSuite extends QueryTest { val schemaWithComplexMap = StructType( StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) - val jsonWithComplexMap = read.schema(schemaWithComplexMap).json(mapType2) + val jsonWithComplexMap = ctx.read.schema(schemaWithComplexMap).json(mapType2) jsonWithComplexMap.registerTempTable("jsonWithComplexMap") @@ -763,7 +762,7 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -781,7 +780,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonDF = read.json(complexFieldAndType2) + val jsonDF = ctx.read.json(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -804,7 +803,7 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonDF = read.json(jsonArray) + val jsonDF = ctx.read.json(jsonArray) jsonDF.registerTempTable("jsonTable") checkAnswer( @@ -822,10 +821,10 @@ class JsonSuite extends QueryTest { test("Corrupt records") { // Test if we can query corrupt records. - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonDF = read.json(corruptRecords) + val jsonDF = ctx.read.json(corruptRecords) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -875,11 +874,11 @@ class JsonSuite extends QueryTest { Row("]") :: Nil ) - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } test("SPARK-4068: nulls in arrays") { - val jsonDF = read.json(nullsInArrays) + val jsonDF = ctx.read.json(nullsInArrays) jsonDF.registerTempTable("jsonTable") val schema = StructType( @@ -925,7 +924,7 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = ctx.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") val df2 = df1.toDF val result = df2.toJSON.collect() @@ -948,7 +947,7 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD2, schema2) + val df3 = ctx.createDataFrame(rowRDD2, schema2) df3.registerTempTable("applySchema2") val df4 = df3.toDF val result2 = df4.toJSON.collect() @@ -956,8 +955,8 @@ class JsonSuite extends QueryTest { assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonDF = read.json(primitiveFieldAndType) - val primTable = read.json(jsonDF.toJSON) + val jsonDF = ctx.read.json(primitiveFieldAndType) + val primTable = ctx.read.json(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -969,8 
+968,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonDF = read.json(complexFieldAndType1) - val compTable = read.json(complexJsonDF.toJSON) + val complexJsonDF = ctx.read.json(complexFieldAndType1) + val compTable = ctx.read.json(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( @@ -1074,29 +1073,29 @@ class JsonSuite extends QueryTest { } test("SPARK-7565 MapType in JsonRDD") { - val useStreaming = getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") - val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord - TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") + val useStreaming = ctx.getConf(SQLConf.USE_JACKSON_STREAMING_API, "true") + val oldColumnNameOfCorruptRecord = ctx.conf.columnNameOfCorruptRecord + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") val schemaWithSimpleMap = StructType( StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) try{ for (useStreaming <- List("true", "false")) { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) val temp = Utils.createTempDir().getPath - val df = read.schema(schemaWithSimpleMap).json(mapType1) + val df = ctx.read.schema(schemaWithSimpleMap).json(mapType1) df.write.mode("overwrite").parquet(temp) // order of MapType is not defined - assert(read.parquet(temp).count() == 5) + assert(ctx.read.parquet(temp).count() == 5) - val df2 = read.json(corruptRecords) + val df2 = ctx.read.json(corruptRecords) df2.write.mode("overwrite").parquet(temp) - checkAnswer(read.parquet(temp), df2.collect()) + checkAnswer(ctx.read.parquet(temp), df2.collect()) } } finally { - setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) - setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) + ctx.setConf(SQLConf.USE_JACKSON_STREAMING_API, useStreaming) + ctx.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, oldColumnNameOfCorruptRecord) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 47a97a49daabb62c0a0b4faeddbf74d0d33c46b7..b6a6a8dc6a63c3e8f9f69bde381312d1a34d2cb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext -object TestJsonData { +trait TestJsonData { - val primitiveFieldAndType = - TestSQLContext.sparkContext.parallelize( + protected def ctx: SQLContext + + def primitiveFieldAndType: RDD[String] = + ctx.sparkContext.parallelize( """{"string":"this is a simple string.", "integer":10, "long":21474836470, @@ -32,8 +35,8 @@ object TestJsonData { "null":null }""" :: Nil) - val primitiveFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def primitiveFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1, "num_bool":true, "num_str":13.1, "str_bool":"str1"}""" :: """{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null, @@ -43,15 +46,15 @@ object TestJsonData { """{"num_num_1":21474836570, "num_num_2":1.1, "num_num_3": 21474836470, "num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: 
Nil) - val jsonNullStruct = - TestSQLContext.sparkContext.parallelize( + def jsonNullStruct: RDD[String] = + ctx.sparkContext.parallelize( """{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":{}}""" :: """{"nullstr":"","ip":"27.31.100.29","headers":""}""" :: """{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil) - val complexFieldValueTypeConflict = - TestSQLContext.sparkContext.parallelize( + def complexFieldValueTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"num_struct":11, "str_array":[1, 2, 3], "array":[], "struct_array":[], "struct": {}}""" :: """{"num_struct":{"field":false}, "str_array":null, @@ -61,23 +64,23 @@ object TestJsonData { """{"num_struct":{}, "str_array":["str1", "str2", 33], "array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil) - val arrayElementTypeConflict = - TestSQLContext.sparkContext.parallelize( + def arrayElementTypeConflict: RDD[String] = + ctx.sparkContext.parallelize( """{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}], "array2": [{"field":214748364700}, {"field":1}]}""" :: """{"array3": [{"field":"str"}, {"field":1}]}""" :: """{"array3": [1, 2, 3]}""" :: Nil) - val missingFields = - TestSQLContext.sparkContext.parallelize( + def missingFields: RDD[String] = + ctx.sparkContext.parallelize( """{"a":true}""" :: """{"b":21474836470}""" :: """{"c":[33, 44]}""" :: """{"d":{"field":true}}""" :: """{"e":"str"}""" :: Nil) - val complexFieldAndType1 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType1: RDD[String] = + ctx.sparkContext.parallelize( """{"struct":{"field1": true, "field2": 92233720368547758070}, "structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]}, "arrayOfString":["str1", "str2"], @@ -92,8 +95,8 @@ object TestJsonData { "arrayOfArray2":[[1, 2, 3], [1.1, 2.1, 3.1]] }""" :: Nil) - val complexFieldAndType2 = - TestSQLContext.sparkContext.parallelize( + def complexFieldAndType2: RDD[String] = + ctx.sparkContext.parallelize( """{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}], "complexArrayOfStruct": [ { @@ -146,16 +149,16 @@ object TestJsonData { ]] }""" :: Nil) - val mapType1 = - TestSQLContext.sparkContext.parallelize( + def mapType1: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": 1}}""" :: """{"map": {"b": 2}}""" :: """{"map": {"c": 3}}""" :: """{"map": {"c": 1, "d": 4}}""" :: """{"map": {"e": null}}""" :: Nil) - val mapType2 = - TestSQLContext.sparkContext.parallelize( + def mapType2: RDD[String] = + ctx.sparkContext.parallelize( """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: """{"map": {"b": {"field2": 2}}}""" :: """{"map": {"c": {"field1": [], "field2": 4}}}""" :: @@ -163,22 +166,22 @@ object TestJsonData { """{"map": {"e": null}}""" :: """{"map": {"f": {"field1": null}}}""" :: Nil) - val nullsInArrays = - TestSQLContext.sparkContext.parallelize( + def nullsInArrays: RDD[String] = + ctx.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: """{"field2":[null, [{"Test":1}]]}""" :: """{"field3":[[null], [{"Test":"2"}]]}""" :: """{"field4":[[null, [1,2,3]]]}""" :: Nil) - val jsonArray = - TestSQLContext.sparkContext.parallelize( + def jsonArray: RDD[String] = + ctx.sparkContext.parallelize( """[{"a":"str_a_1"}]""" :: """[{"a":"str_a_2"}, {"b":"str_b_3"}]""" :: """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """[]""" :: Nil) - val corruptRecords = - 
TestSQLContext.sparkContext.parallelize( + def corruptRecords: RDD[String] = + ctx.sparkContext.parallelize( """{""" :: """""" :: """{"a":1, b:2}""" :: @@ -186,6 +189,5 @@ object TestJsonData { """{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" :: """]""" :: Nil) - val empty = - TestSQLContext.sparkContext.parallelize(Seq[String]()) + def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 4aa5bcb7fdbca8dd110c81e865d6ea47e6652970..17f5f9a491e6bbb7d57595158cbe0b3e60bf891a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.sources.LogicalRelation -import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} @@ -42,7 +41,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf} * data type is nullable. */ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { - val sqlContext = TestSQLContext + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( df: DataFrame, @@ -312,7 +311,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { } class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true") @@ -341,7 +340,7 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - val originalConf = sqlContext.conf.parquetUseDataSourceApi + lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi override protected def beforeAll(): Unit = { sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7f7c2cc1a6c26b8b67f47625b00fefa03aee507f..2b6a27032e637f93b1cce809e8d6919f7c1d5d29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -36,9 +36,6 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode} @@ -66,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS * A test suite that tests basic Parquet I/O. 
 */
 class ParquetIOSuiteBase extends QueryTest with ParquetTest {
-  val sqlContext = TestSQLContext
-
-  import sqlContext.implicits.localSeqToDataFrameHolder
+  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+  import sqlContext.implicits._

   /**
    * Writes `data` to a Parquet file, reads it back and check file contents.
@@ -104,7 +100,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {

   test("fixed-length decimals") {
     def makeDecimalRDD(decimal: DecimalType): DataFrame =
-      sparkContext
+      sqlContext.sparkContext
         .parallelize(0 to 1000)
         .map(i => Tuple1(i / 100.0))
         .toDF()
@@ -115,7 +111,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
       withTempPath { dir =>
         val data = makeDecimalRDD(DecimalType(precision, scale))
         data.write.parquet(dir.getCanonicalPath)
-        checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+        checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
       }
     }

@@ -123,7 +119,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     intercept[Throwable] {
       withTempPath { dir =>
         makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
-        read.parquet(dir.getCanonicalPath).collect()
+        sqlContext.read.parquet(dir.getCanonicalPath).collect()
       }
     }

@@ -131,14 +127,14 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     intercept[Throwable] {
       withTempPath { dir =>
         makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath)
-        read.parquet(dir.getCanonicalPath).collect()
+        sqlContext.read.parquet(dir.getCanonicalPath).collect()
       }
     }
   }

   test("date type") {
     def makeDateRDD(): DataFrame =
-      sparkContext
+      sqlContext.sparkContext
         .parallelize(0 to 1000)
         .map(i => Tuple1(DateUtils.toJavaDate(i)))
         .toDF()
@@ -147,7 +143,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     withTempPath { dir =>
       val data = makeDateRDD()
       data.write.parquet(dir.getCanonicalPath)
-      checkAnswer(read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+      checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
     }
   }

@@ -236,7 +232,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     def checkCompressionCodec(codec: CompressionCodecName): Unit = {
       withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
         withParquetFile(data) { path =>
-          assertResult(conf.parquetCompressionCodec.toUpperCase) {
+          assertResult(sqlContext.conf.parquetCompressionCodec.toUpperCase) {
             compressionCodecFor(path)
           }
         }
@@ -244,7 +240,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     }

     // Checks default compression codec
-    checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
+    checkCompressionCodec(CompressionCodecName.fromConf(sqlContext.conf.parquetCompressionCodec))

     checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
     checkCompressionCodec(CompressionCodecName.GZIP)
@@ -283,7 +279,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "part-r-0.parquet")
       makeRawParquetFile(path)
-      checkAnswer(read.parquet(path.toString), (0 until 10).map { i =>
+      checkAnswer(sqlContext.read.parquet(path.toString), (0 until 10).map { i =>
         Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
       })
     }
@@ -312,7 +308,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
       val newData = (11 to 20).map(i => (i, i.toString))
       newData.toDF().write.format("parquet").mode(SaveMode.Overwrite).save(file)
-      checkAnswer(read.parquet(file), newData.map(Row.fromTuple))
+      checkAnswer(sqlContext.read.parquet(file), newData.map(Row.fromTuple))
     }
   }

@@ -321,7 +317,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     withParquetFile(data) { file =>
       val newData = (11 to 20).map(i => (i, i.toString))
       newData.toDF().write.format("parquet").mode(SaveMode.Ignore).save(file)
-      checkAnswer(read.parquet(file), data.map(Row.fromTuple))
+      checkAnswer(sqlContext.read.parquet(file), data.map(Row.fromTuple))
     }
   }

@@ -341,7 +337,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
     withParquetFile(data) { file =>
       val newData = (11 to 20).map(i => (i, i.toString))
       newData.toDF().write.format("parquet").mode(SaveMode.Append).save(file)
-      checkAnswer(read.parquet(file), (data ++ newData).map(Row.fromTuple))
+      checkAnswer(sqlContext.read.parquet(file), (data ++ newData).map(Row.fromTuple))
     }
   }

@@ -369,11 +365,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
       val path = new Path(location.getCanonicalPath)

       ParquetFileWriter.writeMetadataFile(
-        sparkContext.hadoopConfiguration,
+        sqlContext.sparkContext.hadoopConfiguration,
         path,
         new Footer(path, new ParquetMetadata(fileMetadata, Nil)) :: Nil)

-      assertResult(read.parquet(path.toString).schema) {
+      assertResult(sqlContext.read.parquet(path.toString).schema) {
         StructType(
           StructField("a", BooleanType, nullable = false) ::
           StructField("b", IntegerType, nullable = false) ::
@@ -406,7 +402,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest {
 }

 class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
-  val originalConf = sqlContext.conf.parquetUseDataSourceApi
+  private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi

   override protected def beforeAll(): Unit = {
     sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -430,7 +426,7 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA
 }

 class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
-  val originalConf = sqlContext.conf.parquetUseDataSourceApi
+  private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi

   override protected def beforeAll(): Unit = {
     sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index 3b29979452ad99eacdae03cbe693692df216fd5e..8979a0a210a42760a0bf118dc07cb463c832f340 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -14,6 +14,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
+
 package org.apache.spark.sql.parquet

 import java.io.File
@@ -28,7 +29,6 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.sources.PartitioningUtils._
 import org.apache.spark.sql.sources.{LogicalRelation, Partition, PartitionSpec}
-import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, QueryTest, Row, SQLContext}

@@ -39,10 +39,10 @@ case class ParquetData(intField: Int, stringField: String)
 case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)

 class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
-  override val sqlContext: SQLContext = TestSQLContext
-  import sqlContext._
+  override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext

   import sqlContext.implicits._
+  import sqlContext.sql

   val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"

@@ -190,8 +190,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
       // Introduce _temporary dir to the base dir the robustness of the schema discovery process.
       new File(base.getCanonicalPath, "_temporary").mkdir()

-      println("load the partitioned table")
-      read.parquet(base.getCanonicalPath).registerTempTable("t")
+      sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")

       withTempTable("t") {
         checkAnswer(
@@ -238,7 +237,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
           makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
       }

-      read.parquet(base.getCanonicalPath).registerTempTable("t")
+      sqlContext.read.parquet(base.getCanonicalPath).registerTempTable("t")

       withTempTable("t") {
         checkAnswer(
@@ -286,7 +285,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
           makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
       }

-      val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+      val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
       parquetRelation.registerTempTable("t")

       withTempTable("t") {
@@ -326,7 +325,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
           makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
       }

-      val parquetRelation = read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath)
+      val parquetRelation = sqlContext.read.format("parquet").load(base.getCanonicalPath)
       parquetRelation.registerTempTable("t")

       withTempTable("t") {
@@ -358,7 +357,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
         (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
         makePartitionDir(base, defaultPartitionName, "pi" -> 2))

-      read.format("org.apache.spark.sql.parquet").load(base.getCanonicalPath).registerTempTable("t")
+      sqlContext.read.format("parquet").load(base.getCanonicalPath).registerTempTable("t")

       withTempTable("t") {
         checkAnswer(
@@ -371,7 +370,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
   test("SPARK-7749 Non-partitioned table should have empty partition spec") {
     withTempPath { dir =>
       (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
-      val queryExecution = read.parquet(dir.getCanonicalPath).queryExecution
+      val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution
       queryExecution.analyzed.collectFirst {
         case LogicalRelation(relation: ParquetRelation2) =>
           assert(relation.partitionSpec === PartitionSpec.emptySpec)
@@ -385,7 +384,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
     withTempPath { dir =>
       val df = Seq("/", "[]", "?").zipWithIndex.map(_.swap).toDF("i", "s")
       df.write.format("parquet").partitionBy("s").save(dir.getCanonicalPath)
-      checkAnswer(read.parquet(dir.getCanonicalPath), df.collect())
+      checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), df.collect())
     }
   }

@@ -425,12 +424,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
     }

     val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
-    val df = createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+    val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(row :: Nil), schema)

     withTempPath { dir =>
       df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
       val fields = schema.map(f => Column(f.name).cast(f.dataType))
-      checkAnswer(read.load(dir.toString).select(fields: _*), row)
+      checkAnswer(sqlContext.read.load(dir.toString).select(fields: _*), row)
     }
   }

@@ -446,7 +445,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
       Files.touch(new File(s"${dir.getCanonicalPath}/b=1", ".DS_Store"))
       Files.createParentDirs(new File(s"${dir.getCanonicalPath}/b=1/c=1/.foo/bar"))

-      checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df)
+      checkAnswer(sqlContext.read.format("parquet").load(dir.getCanonicalPath), df)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 304936fb2be8e0898ec67b94f5e0b83b75cda0ca..de0107a3618159b58dced813886e29121b519da6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -22,14 +22,14 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{SQLConf, QueryTest}
 import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.test.TestSQLContext._

 /**
 * A test suite that tests various Parquet queries.
 */
 class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
-  val sqlContext = TestSQLContext
+  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+  import sqlContext.implicits._
+  import sqlContext.sql

   test("simple select queries") {
     withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -40,22 +40,22 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {

   test("appending") {
     val data = (0 until 10).map(i => (i, i.toString))
-    createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT INTO TABLE t SELECT * FROM tmp")
-      checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
+      checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
     }
-    catalog.unregisterTable(Seq("tmp"))
+    sqlContext.catalog.unregisterTable(Seq("tmp"))
   }

   test("overwriting") {
     val data = (0 until 10).map(i => (i, i.toString))
-    createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+    sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
     withParquetTable(data, "t") {
       sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
-      checkAnswer(table("t"), data.map(Row.fromTuple))
+      checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
     }
-    catalog.unregisterTable(Seq("tmp"))
+    sqlContext.catalog.unregisterTable(Seq("tmp"))
   }

   test("self-join") {
@@ -118,7 +118,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
     val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
       StructField("time", TimestampType, false)).toArray)
     withTempPath { file =>
-      val df = sqlContext.createDataFrame(sparkContext.parallelize(data), schema)
+      val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
       df.write.parquet(file.getCanonicalPath)
       val df2 = sqlContext.read.parquet(file.getCanonicalPath)
       checkAnswer(df2, df.collect().toSeq)
@@ -127,7 +127,7 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
 }

 class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
-  val originalConf = sqlContext.conf.parquetUseDataSourceApi
+  private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi

   override protected def beforeAll(): Unit = {
     sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
@@ -139,7 +139,7 @@ class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAnd
 }

 class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
-  val originalConf = sqlContext.conf.parquetUseDataSourceApi
+  private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi

   override protected def beforeAll(): Unit = {
     sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index 8b1745124b8e13d81b48fadb918c0fc73199242e..171a656f0e01ede767a214b238edf1d89c1bc34b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -24,11 +24,10 @@ import org.apache.parquet.schema.MessageTypeParser

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.types._

 class ParquetSchemaSuite extends SparkFunSuite with ParquetTest {
-  val sqlContext = TestSQLContext
+  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext

   /**
    * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 516ba373f41d290e234fbfec7173280c217c427a..eb15a1609f1d0a2b05a0f1d6c01509e0d4b833b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -33,8 +33,6 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
 * Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
 */
 private[sql] trait ParquetTest extends SQLTestUtils {
-  import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
-  import sqlContext.sparkContext

   /**
    * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
@@ -44,7 +42,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {
       (data: Seq[T])
       (f: String => Unit): Unit = {
     withTempPath { file =>
-      sparkContext.parallelize(data).toDF().write.parquet(file.getCanonicalPath)
+      sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
       f(file.getCanonicalPath)
     }
   }
@@ -75,7 +73,7 @@ private[sql] trait ParquetTest extends SQLTestUtils {

   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
       data: Seq[T], path: File): Unit = {
-    data.toDF().write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+    sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
   }

   protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 17a8b0cca09dfe0c440639d94e5c451a7727dfae..ac4a00a6f3dacb8bfd7dfe339f714c6191adf0db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -25,11 +25,9 @@ import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils

 trait SQLTestUtils {
-  val sqlContext: SQLContext
+  def sqlContext: SQLContext

-  import sqlContext.{conf, sparkContext}
-
-  protected def configuration = sparkContext.hadoopConfiguration
+  protected def configuration = sqlContext.sparkContext.hadoopConfiguration

   /**
    * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
@@ -39,12 +37,12 @@ trait SQLTestUtils {
    */
   protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
     val (keys, values) = pairs.unzip
-    val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
-    (keys, values).zipped.foreach(conf.setConf)
+    val currentValues = keys.map(key => Try(sqlContext.conf.getConf(key)).toOption)
+    (keys, values).zipped.foreach(sqlContext.conf.setConf)
     try f finally {
       keys.zip(currentValues).foreach {
-        case (key, Some(value)) => conf.setConf(key, value)
-        case (key, None) => conf.unsetConf(key)
+        case (key, Some(value)) => sqlContext.conf.setConf(key, value)
+        case (key, None) => sqlContext.conf.unsetConf(key)
       }
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
index 57c23fe77f8b5a7d81974a64239c7b9a5a52cff2..b384fb39f3d66676663c543dd5328910eb13efba 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala
@@ -52,9 +52,6 @@ case class Contact(name: String, phone: String)
 case class Person(name: String, age: Int, contacts: Seq[Contact])

 class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {
-  override val sqlContext = TestHive
-
-  import TestHive.read

   def getTempFilePath(prefix: String, suffix: String = ""): File = {
     val tempFile = File.createTempFile(prefix, suffix)
@@ -69,7 +66,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest {

     withOrcFile(data) { file =>
       checkAnswer(
-        read.format("orc").load(file),
+        sqlContext.read.format("orc").load(file),
         data.toDF().collect())
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 750f0b04aaa87bfebeeb649f5f7f5d258f5c987b..5daf691aa8c53b1797d82c236b5704e0d58549f0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -22,13 +22,11 @@ import java.io.File
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag

-import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql._

 private[sql] trait OrcTest extends SQLTestUtils {
-  protected def hiveContext = sqlContext.asInstanceOf[HiveContext]
+  lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive

   import sqlContext.sparkContext
   import sqlContext.implicits._
@@ -53,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
   protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag]
       (data: Seq[T])
       (f: DataFrame => Unit): Unit = {
-    withOrcFile(data)(path => f(hiveContext.read.format("orc").load(path)))
+    withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path)))
   }

   /**
@@ -65,7 +63,7 @@ private[sql] trait OrcTest extends SQLTestUtils {
       (data: Seq[T], tableName: String)
       (f: => Unit): Unit = {
     withOrcDataFrame(data) { df =>
-      hiveContext.registerDataFrameAsTable(df, tableName)
+      sqlContext.registerDataFrameAsTable(df, tableName)
       withTempTable(tableName)(f)
     }
   }
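For reference, a minimal sketch (not part of the patch) of the suite-local context pattern the diffs above converge on: a test suite obtains the shared `TestSQLContext` through a `lazy val` and imports its implicits and `sql` method locally, instead of relying on top-level `TestSQLContext._` imports. The suite, package, and table names below are hypothetical and only illustrate the shape of the pattern.

```scala
package org.apache.spark.sql.example   // hypothetical package, for illustration only

import org.apache.spark.sql.QueryTest

class ExampleQuerySuite extends QueryTest {

  // Lazily initialized so the shared SQLContext (and its SparkContext) is only
  // created when a test in this suite actually runs.
  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._   // toDF() and other conversions, scoped to this suite
  import ctx.sql           // suite-local alias for ctx.sql(...)

  test("round-trips a small DataFrame through a temp table") {
    val df = ctx.sparkContext.parallelize(1 to 3).map(i => (i, i.toString)).toDF("key", "value")
    df.registerTempTable("example_data")   // hypothetical temp table name
    checkAnswer(sql("SELECT key, value FROM example_data"), df.collect().toSeq)
  }
}
```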