diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
index 6ca5390cde23e782a718b43981d1fc0e622b29b6..8631e247c6c055b94af48b4c2ba8adea65bf83cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import java.sql.{Timestamp, Date}
 
+import org.apache.spark.sql.test.TestSQLContext
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.rdd.ShuffledRDD
@@ -26,7 +27,6 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.{ShuffleDependency, SparkFunSuite}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
 
 class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite {
@@ -74,11 +74,13 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   var numShufflePartitions: Int = _
   var useSerializer2: Boolean = _
 
+  protected lazy val ctx = TestSQLContext
+
   override def beforeAll(): Unit = {
-    numShufflePartitions = conf.numShufflePartitions
-    useSerializer2 = conf.useSqlSerializer2
+    numShufflePartitions = ctx.conf.numShufflePartitions
+    useSerializer2 = ctx.conf.useSqlSerializer2
 
-    sql("set spark.sql.useSerializer2=true")
+    ctx.sql("set spark.sql.useSerializer2=true")
 
     val supportedTypes =
       Seq(StringType, BinaryType, NullType, BooleanType,
@@ -94,7 +96,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
 
     // Create a RDD with all data types supported by SparkSqlSerializer2.
     val rdd =
-      sparkContext.parallelize((1 to 1000), 10).map { i =>
+      ctx.sparkContext.parallelize((1 to 1000), 10).map { i =>
         Row(
           s"str${i}: test serializer2.",
           s"binary${i}: test serializer2.".getBytes("UTF-8"),
@@ -112,15 +114,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
           new Timestamp(i))
       }
 
-    createDataFrame(rdd, schema).registerTempTable("shuffle")
+    ctx.createDataFrame(rdd, schema).registerTempTable("shuffle")
 
     super.beforeAll()
   }
 
   override def afterAll(): Unit = {
-    dropTempTable("shuffle")
-    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
-    sql(s"set spark.sql.useSerializer2=$useSerializer2")
+    ctx.dropTempTable("shuffle")
+    ctx.sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    ctx.sql(s"set spark.sql.useSerializer2=$useSerializer2")
     super.afterAll()
   }
 
@@ -141,16 +143,16 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("key schema and value schema are not nulls") {
-    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    val df = ctx.sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
-      table("shuffle").collect())
+      ctx.table("shuffle").collect())
   }
 
   test("key schema is null") {
     val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sql(s"SELECT $aggregations FROM shuffle")
+    val df = ctx.sql(s"SELECT $aggregations FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
     checkAnswer(
       df,
@@ -158,15 +160,14 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
   }
 
   test("value schema is null") {
-    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    val df = ctx.sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    assert(
-      df.map(r => r.getString(0)).collect().toSeq ===
-      table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
+    assert(df.map(r => r.getString(0)).collect().toSeq ===
+      ctx.table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
   }
 
   test("no map output field") {
-    val df = sql(s"SELECT 1 + 1 FROM shuffle")
+    val df = ctx.sql(s"SELECT 1 + 1 FROM shuffle")
     checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
   }
 }
@@ -177,8 +178,8 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
     super.beforeAll()
     // Sort merge will not be triggered.
     val bypassMergeThreshold =
-      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
+      ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
   }
 }
 
@@ -189,7 +190,7 @@ class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite
     super.beforeAll()
     // To trigger the sort merge.
     val bypassMergeThreshold =
-      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
+      ctx.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    ctx.sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index d2d1011b8e917170e90cf54eb7b04fd2c68bc585..a71088430bfd556295085033033cf9a117943e3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -26,18 +26,20 @@ import org.apache.spark.util.Utils
 
 class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
 
-  import caseInsensitiveContext._
+  import caseInsensitiveContext.sql
+
+  private lazy val sparkContext = caseInsensitiveContext.sparkContext
 
   var path: File = null
 
   override def beforeAll(): Unit = {
     path = Utils.createTempDir()
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
-    read.json(rdd).registerTempTable("jt")
+    caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
   }
 
   override def afterAll(): Unit = {
-    dropTempTable("jt")
+    caseInsensitiveContext.dropTempTable("jt")
   }
 
   after {
@@ -59,7 +61,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
       sql("SELECT a, b FROM jsonTable"),
       sql("SELECT a, b FROM jt").collect())
 
-    dropTempTable("jsonTable")
+    caseInsensitiveContext.dropTempTable("jsonTable")
   }
 
   test("CREATE TEMPORARY TABLE AS SELECT based on the file without write permission") {
@@ -129,7 +131,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
       sql("SELECT * FROM jsonTable"),
       sql("SELECT a * 4 FROM jt").collect())
 
-    dropTempTable("jsonTable")
+    caseInsensitiveContext.dropTempTable("jsonTable")
 
     // Explicitly delete the data.
     if (path.exists()) Utils.deleteRecursively(path)
 
@@ -147,7 +149,7 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
       sql("SELECT * FROM jsonTable"),
       sql("SELECT b FROM jt").collect())
 
-    dropTempTable("jsonTable")
+    caseInsensitiveContext.dropTempTable("jsonTable")
   }
 
   test("CREATE TEMPORARY TABLE AS SELECT with IF NOT EXISTS is not allowed") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 5c3467158a01b7701b7ce846f3f009d98e0c0e0d..51d22b6a1378abaa87351e9874156a85be9a2a1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -63,19 +63,18 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
 }
 
 class DDLTestSuite extends DataSourceTest {
-  import caseInsensitiveContext._
 
   before {
-    sql(
-      """
-        |CREATE TEMPORARY TABLE ddlPeople
-        |USING org.apache.spark.sql.sources.DDLScanSource
-        |OPTIONS (
-        |  From '1',
-        |  To '10',
-        |  Table 'test1'
-        |)
-      """.stripMargin)
+    caseInsensitiveContext.sql(
+      """
+      |CREATE TEMPORARY TABLE ddlPeople
+      |USING org.apache.spark.sql.sources.DDLScanSource
+      |OPTIONS (
+      |  From '1',
+      |  To '10',
+      |  Table 'test1'
+      |)
+      """.stripMargin)
   }
 
   sqlTest(
@@ -100,7 +99,8 @@ class DDLTestSuite extends DataSourceTest {
       ))
 
   test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
-    val attributes = sql("describe ddlPeople").queryExecution.executedPlan.output
+    val attributes = caseInsensitiveContext.sql("describe ddlPeople")
+      .queryExecution.executedPlan.output
     assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
     assert(attributes.map(_.dataType).toSet === Set(StringType))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 24ed665c67d2e30ac09a0c2dbbe8ba1c43cb30d4..3f77960d09246a1365c29da5d72979d7b74d290e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -17,14 +17,18 @@
 
 package org.apache.spark.sql.sources
 
+import org.scalatest.BeforeAndAfter
+
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.CatalystConf
 import org.apache.spark.sql.test.TestSQLContext
-import org.scalatest.BeforeAndAfter
+
 
 abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
   // We want to test some edge cases.
-  implicit val caseInsensitiveContext = new SQLContext(TestSQLContext.sparkContext)
+  protected implicit lazy val caseInsensitiveContext = {
+    val ctx = new SQLContext(TestSQLContext.sparkContext)
+    ctx.setConf(SQLConf.CASE_SENSITIVE, "false")
+    ctx
+  }
 
-  caseInsensitiveContext.setConf(SQLConf.CASE_SENSITIVE, "false")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index db94b1f3e8926ca26942540d61a94239ef2bd812..81b3a0f0c5b3ab285f84e63b045b10fce8afe51e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -97,7 +97,7 @@ object FiltersPushed {
 
 class FilteredScanSuite extends DataSourceTest {
 
-  import caseInsensitiveContext._
+  import caseInsensitiveContext.sql
 
   before {
     sql(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 6f375ef36237d8559348886efd3cab0278852496..0b7c46c482c889e41ca22057bd30099a8755a766 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -26,14 +26,16 @@ import org.apache.spark.util.Utils
 
 class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
 
-  import caseInsensitiveContext._
+  import caseInsensitiveContext.sql
+
+  private lazy val sparkContext = caseInsensitiveContext.sparkContext
 
   var path: File = null
 
   override def beforeAll: Unit = {
     path = Utils.createTempDir()
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
-    read.json(rdd).registerTempTable("jt")
+    caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
     sql(
       s"""
         |CREATE TEMPORARY TABLE jsonTable (a int, b string)
@@ -45,8 +47,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
   }
 
   override def afterAll: Unit = {
-    dropTempTable("jsonTable")
-    dropTempTable("jt")
+    caseInsensitiveContext.dropTempTable("jsonTable")
+    caseInsensitiveContext.dropTempTable("jt")
     Utils.deleteRecursively(path)
   }
 
@@ -109,7 +111,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
 
     // Writing the table to less part files.
    val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 5)
-    read.json(rdd1).registerTempTable("jt1")
+    caseInsensitiveContext.read.json(rdd1).registerTempTable("jt1")
     sql(
       s"""
         |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt1
@@ -121,7 +123,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
 
     // Writing the table to more part files.
     val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""), 10)
-    read.json(rdd2).registerTempTable("jt2")
+    caseInsensitiveContext.read.json(rdd2).registerTempTable("jt2")
     sql(
       s"""
         |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jt2
@@ -140,8 +142,8 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
       (1 to 10).map(i => Row(i * 10, s"str$i"))
     )
 
-    dropTempTable("jt1")
-    dropTempTable("jt2")
+    caseInsensitiveContext.dropTempTable("jt1")
+    caseInsensitiveContext.dropTempTable("jt2")
   }
 
   test("INSERT INTO not supported for JSONRelation for now") {
@@ -154,13 +156,14 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
   }
 
   test("save directly to the path of a JSON table") {
-    table("jt").selectExpr("a * 5 as a", "b").write.mode(SaveMode.Overwrite).json(path.toString)
+    caseInsensitiveContext.table("jt").selectExpr("a * 5 as a", "b")
+      .write.mode(SaveMode.Overwrite).json(path.toString)
     checkAnswer(
       sql("SELECT a, b FROM jsonTable"),
       (1 to 10).map(i => Row(i * 5, s"str$i"))
     )
 
-    table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
+    caseInsensitiveContext.table("jt").write.mode(SaveMode.Overwrite).json(path.toString)
     checkAnswer(
       sql("SELECT a, b FROM jsonTable"),
       (1 to 10).map(i => Row(i, s"str$i"))
@@ -181,7 +184,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
 
   test("Caching") {
     // Cached Query Execution
-    cacheTable("jsonTable")
+    caseInsensitiveContext.cacheTable("jsonTable")
     assertCached(sql("SELECT * FROM jsonTable"))
     checkAnswer(
       sql("SELECT * FROM jsonTable"),
@@ -220,7 +223,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
       sql("SELECT a * 2, b FROM jt").collect())
 
     // Verify uncaching
-    uncacheTable("jsonTable")
+    caseInsensitiveContext.uncacheTable("jsonTable")
     assertCached(sql("SELECT * FROM jsonTable"), 0)
   }
 
@@ -251,6 +254,6 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
       "It is not allowed to insert into a table that is not an InsertableRelation."
     )
 
-    dropTempTable("oneToTen")
+    caseInsensitiveContext.dropTempTable("oneToTen")
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index c2bc52e2120c1662a4ce0db27499f765f8140994..257526feab9452501819c2e48042cbe7dbfcb99b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -52,10 +52,9 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
 }
 
 class PrunedScanSuite extends DataSourceTest {
-  import caseInsensitiveContext._
 
   before {
-    sql(
+    caseInsensitiveContext.sql(
       """
         |CREATE TEMPORARY TABLE oneToTenPruned
         |USING org.apache.spark.sql.sources.PrunedScanSource
@@ -115,7 +114,7 @@ class PrunedScanSuite extends DataSourceTest {
 
   def testPruning(sqlString: String, expectedColumns: String*): Unit = {
     test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
-      val queryExecution = sql(sqlString).queryExecution
+      val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
       val rawPlan = queryExecution.executedPlan.collect {
         case p: execution.PhysicalRDD => p
       } match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index 274c652dd14d6a2fd71fc7ef76b17d5608756eb2..b032515a9d28c671fff33474ba050a34a5f24259 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -27,7 +27,9 @@ import org.apache.spark.util.Utils
 
 class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
 
-  import caseInsensitiveContext._
+  import caseInsensitiveContext.sql
+
+  private lazy val sparkContext = caseInsensitiveContext.sparkContext
 
   var originalDefaultSource: String = null
 
@@ -36,60 +38,63 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
   var df: DataFrame = null
 
   override def beforeAll(): Unit = {
-    originalDefaultSource = conf.defaultDataSourceName
+    originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
 
     path = Utils.createTempDir()
     path.delete()
 
     val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
-    df = read.json(rdd)
+    df = caseInsensitiveContext.read.json(rdd)
     df.registerTempTable("jsonTable")
   }
 
   override def afterAll(): Unit = {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
   }
 
   after {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
     Utils.deleteRecursively(path)
   }
 
   def checkLoad(): Unit = {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
-    checkAnswer(read.load(path.toString), df.collect())
+    caseInsensitiveContext.conf.setConf(
+      SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+    checkAnswer(caseInsensitiveContext.read.load(path.toString), df.collect())
 
     // Test if we can pick up the data source name passed in load.
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
-    checkAnswer(read.format("json").load(path.toString), df.collect())
-    checkAnswer(read.format("json").load(path.toString), df.collect())
+    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+    checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect())
+    checkAnswer(caseInsensitiveContext.read.format("json").load(path.toString), df.collect())
 
     val schema = StructType(StructField("b", StringType, true) :: Nil)
     checkAnswer(
-      read.format("json").schema(schema).load(path.toString),
+      caseInsensitiveContext.read.format("json").schema(schema).load(path.toString),
       sql("SELECT b FROM jsonTable").collect())
   }
 
   test("save with path and load") {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+    caseInsensitiveContext.conf.setConf(
+      SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
     df.write.save(path.toString)
     checkLoad()
   }
 
   test("save with string mode and path, and load") {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+    caseInsensitiveContext.conf.setConf(
+      SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
     path.createNewFile()
     df.write.mode("overwrite").save(path.toString)
     checkLoad()
   }
 
   test("save with path and datasource, and load") {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
     df.write.json(path.toString)
     checkLoad()
   }
 
   test("save with data source and options, and load") {
-    conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+    caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
     df.write.mode(SaveMode.ErrorIfExists).json(path.toString)
     checkLoad()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 77af04a491742ca5e8de25c6ee764d3f9bdb8932..5d4ecd810862cf2d1d4718e22f5d4ea66039184b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -88,9 +88,9 @@ case class AllDataTypesScan(
 }
 
 class TableScanSuite extends DataSourceTest {
-  import caseInsensitiveContext._
+  import caseInsensitiveContext.sql
 
-  var tableWithSchemaExpected = (1 to 10).map { i =>
+  private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
     Row(
       s"str_$i",
       s"str_$i",
@@ -215,7 +215,7 @@ class TableScanSuite extends DataSourceTest {
         Nil
       )
 
-    assert(expectedSchema == table("tableWithSchema").schema)
+    assert(expectedSchema == caseInsensitiveContext.table("tableWithSchema").schema)
 
     checkAnswer(
       sql(
@@ -270,7 +270,7 @@ class TableScanSuite extends DataSourceTest {
 
   test("Caching") {
     // Cached Query Execution
-    cacheTable("oneToTen")
+    caseInsensitiveContext.cacheTable("oneToTen")
     assertCached(sql("SELECT * FROM oneToTen"))
     checkAnswer(
       sql("SELECT * FROM oneToTen"),
@@ -297,7 +297,7 @@ class TableScanSuite extends DataSourceTest {
       (2 to 10).map(i => Row(i, i - 1)).toSeq)
 
     // Verify uncaching
-    uncacheTable("oneToTen")
+    caseInsensitiveContext.uncacheTable("oneToTen")
     assertCached(sql("SELECT * FROM oneToTen"), 0)
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
index 4990092df6a99d97bad7d68b734e9569f549639e..017bc2adc103b40b3910f21cb01b9b5ac2650de3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
@@ -20,16 +20,17 @@ package org.apache.spark.sql.hive
 import com.google.common.io.Files
 
 import org.apache.spark.sql.{QueryTest, _}
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.util.Utils
 
 class QueryPartitionSuite extends QueryTest {
-  import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+  private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+  import ctx.implicits._
+  import ctx.sql
 
   test("SPARK-5068: query data when path doesn't exist"){
-    val testData = TestHive.sparkContext.parallelize(
+    val testData = ctx.sparkContext.parallelize(
       (1 to 10).map(i => TestData(i, i.toString))).toDF()
     testData.registerTempTable("testData")
 
@@ -48,8 +49,8 @@ class QueryPartitionSuite extends QueryTest {
 
     // test for the exist path
     checkAnswer(sql("select key,value from table_with_partition"),
-      testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
-      ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect)
+      testData.toDF.collect ++ testData.toDF.collect
+        ++ testData.toDF.collect ++ testData.toDF.collect)
 
     // delete the path of one partition
     tmpDir.listFiles
@@ -58,8 +59,7 @@ class QueryPartitionSuite extends QueryTest {
 
     // test for after delete the path
     checkAnswer(sql("select key,value from table_with_partition"),
-      testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
-      ++ testData.toSchemaRDD.collect)
+      testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)
 
     sql("DROP TABLE table_with_partition")
     sql("DROP TABLE createAndInsertTest")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
index a492ecf203d17967e57ed42d630f438ef610c6b4..93dcb10f7a29672c290818aa5e2ac3e6d73da3ca 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.hive
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.sql.hive.test.TestHive
 
 class SerializationSuite extends SparkFunSuite {
 
   test("[SPARK-5840] HiveContext should be serializable") {
-    val hiveContext = TestHive
+    val hiveContext = org.apache.spark.sql.hive.test.TestHive
     hiveContext.hiveconf
     val serializer = new JavaSerializer(new SparkConf()).newInstance()
     val bytes = serializer.serialize(hiveContext)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index e16e530555aee6fe71b99fa82ca2528669cf8e77..78c94e6490e362334cc4bc6c7fb4d1c6ad7bca39 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -23,13 +23,18 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.{Row, SQLConf, QueryTest}
 import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.hive.test.TestHive
-import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.execution._
 
 class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
-  TestHive.reset()
-  TestHive.cacheTables = false
+
+  private lazy val ctx: HiveContext = {
+    val ctx = org.apache.spark.sql.hive.test.TestHive
+    ctx.reset()
+    ctx.cacheTables = false
+    ctx
+  }
+
+  import ctx.sql
 
   test("parse analyze commands") {
     def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
@@ -72,7 +77,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
 
   test("analyze MetastoreRelations") {
     def queryTotalSize(tableName: String): BigInt =
-      catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
+      ctx.catalog.lookupRelation(Seq(tableName)).statistics.sizeInBytes
 
     // Non-partitioned table
     sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@@ -106,7 +111,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
         |SELECT * FROM src
       """.stripMargin).collect()
 
-    assert(queryTotalSize("analyzeTable_part") === conf.defaultSizeInBytes)
+    assert(queryTotalSize("analyzeTable_part") === ctx.conf.defaultSizeInBytes)
 
     sql("ANALYZE TABLE analyzeTable_part COMPUTE STATISTICS noscan")
 
@@ -117,9 +122,9 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
     // Try to analyze a temp table
     sql("""SELECT * FROM src""").registerTempTable("tempTable")
     intercept[UnsupportedOperationException] {
-      analyze("tempTable")
+      ctx.analyze("tempTable")
     }
-    catalog.unregisterTable(Seq("tempTable"))
+    ctx.catalog.unregisterTable(Seq("tempTable"))
   }
 
   test("estimates the size of a test MetastoreRelation") {
@@ -147,8 +152,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
       val sizes = df.queryExecution.analyzed.collect {
         case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
       }
-      assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold
-        && sizes(1) <= conf.autoBroadcastJoinThreshold,
+      assert(sizes.size === 2 && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold
+        && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold,
         s"query should contain two relations, each of which has size smaller than autoConvertSize")
 
       // Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -159,8 +164,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
 
      checkAnswer(df, expectedAnswer) // check correctness of output
 
-      TestHive.conf.settings.synchronized {
-        val tmp = conf.autoBroadcastJoinThreshold
+      ctx.conf.settings.synchronized {
+        val tmp = ctx.conf.autoBroadcastJoinThreshold
 
        sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
        df = sql(query)
@@ -203,8 +208,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
          .isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
      }
-      assert(sizes.size === 2 && sizes(1) <= conf.autoBroadcastJoinThreshold
-        && sizes(0) <= conf.autoBroadcastJoinThreshold,
+      assert(sizes.size === 2 && sizes(1) <= ctx.conf.autoBroadcastJoinThreshold
+        && sizes(0) <= ctx.conf.autoBroadcastJoinThreshold,
        s"query should contain two relations, each of which has size smaller than autoConvertSize")
 
      // Using `sparkPlan` because for relevant patterns in HashJoin to be
@@ -217,8 +222,8 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
 
      checkAnswer(df, answer) // check correctness of output
 
-      TestHive.conf.settings.synchronized {
-        val tmp = conf.autoBroadcastJoinThreshold
+      ctx.conf.settings.synchronized {
+        val tmp = ctx.conf.autoBroadcastJoinThreshold
 
        sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
        df = sql(leftSemiJoinQuery)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 8245047626d57bf71c51559b11cdfde715418dd8..4056dee7775748f4eb98ce4011c25f5ce98e13ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,20 +17,20 @@
 
 package org.apache.spark.sql.hive
 
-/* Implicits */
-
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.hive.test.TestHive._
 
 case class FunctionResult(f1: String, f2: String)
 
 class UDFSuite extends QueryTest {
+
+  private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
+
   test("UDF case insensitive") {
-    udf.register("random0", () => { Math.random() })
-    udf.register("RANDOM1", () => { Math.random() })
-    udf.register("strlenScala", (_: String).length + (_: Int))
-    assert(sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
-    assert(sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
-    assert(sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
+    ctx.udf.register("random0", () => { Math.random() })
+    ctx.udf.register("RANDOM1", () => { Math.random() })
+    ctx.udf.register("strlenScala", (_: String).length + (_: Int))
+    assert(ctx.sql("SELECT RANDOM0() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+    assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0)
+    assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5)
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 74095426741e3ef5883d1b6dc94bfe5bdfb58ad6..8787663a98f8fbffa163a386522939131cf64343 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -30,9 +30,9 @@ import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 
 abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
-  override val sqlContext: SQLContext = TestHive
+  override lazy val sqlContext: SQLContext = TestHive
 
-  import sqlContext._
+  import sqlContext.sql
   import sqlContext.implicits._
 
   val dataSourceName = classOf[SimpleTextSource].getCanonicalName
@@ -43,19 +43,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
       StructField("a", IntegerType, nullable = false),
       StructField("b", StringType, nullable = false)))
 
-  val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b")
+  lazy val testDF = (1 to 3).map(i => (i, s"val_$i")).toDF("a", "b")
 
-  val partitionedTestDF1 = (for {
+  lazy val partitionedTestDF1 = (for {
     i <- 1 to 3
     p2 <- Seq("foo", "bar")
   } yield (i, s"val_$i", 1, p2)).toDF("a", "b", "p1", "p2")
 
-  val partitionedTestDF2 = (for {
+  lazy val partitionedTestDF2 = (for {
     i <- 1 to 3
     p2 <- Seq("foo", "bar")
   } yield (i, s"val_$i", 2, p2)).toDF("a", "b", "p1", "p2")
 
-  val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
+  lazy val partitionedTestDF = partitionedTestDF1.unionAll(partitionedTestDF2)
 
   def checkQueries(df: DataFrame): Unit = {
     // Selects everything
@@ -103,7 +103,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
       testDF.write.mode(SaveMode.Overwrite).format(dataSourceName).save(file.getCanonicalPath)
 
       checkAnswer(
-        read.format(dataSourceName)
+        sqlContext.read.format(dataSourceName)
          .option("path", file.getCanonicalPath)
.option("dataSchema", dataSchema.json) .load(), @@ -117,7 +117,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.mode(SaveMode.Append).format(dataSourceName).save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath).orderBy("a"), testDF.unionAll(testDF).orderBy("a").collect()) @@ -151,7 +151,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkQueries( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath)) } @@ -172,7 +172,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -194,7 +194,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.unionAll(partitionedTestDF).collect()) @@ -216,7 +216,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .save(file.getCanonicalPath) checkAnswer( - read.format(dataSourceName) + sqlContext.read.format(dataSourceName) .option("dataSchema", dataSchema.json) .load(file.getCanonicalPath), partitionedTestDF.collect()) @@ -252,7 +252,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.collect()) + checkAnswer(sqlContext.table("t"), testDF.collect()) } } @@ -261,7 +261,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { testDF.write.format(dataSourceName).mode(SaveMode.Append).saveAsTable("t") withTable("t") { - checkAnswer(table("t"), testDF.unionAll(testDF).orderBy("a").collect()) + checkAnswer(sqlContext.table("t"), testDF.unionAll(testDF).orderBy("a").collect()) } } @@ -280,7 +280,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { withTempTable("t") { testDF.write.format(dataSourceName).mode(SaveMode.Ignore).saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -291,7 +291,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkQueries(table("t")) + checkQueries(sqlContext.table("t")) } } @@ -311,7 +311,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -331,7 +331,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.unionAll(partitionedTestDF).collect()) } } @@ -351,7 +351,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), partitionedTestDF.collect()) + checkAnswer(sqlContext.table("t"), partitionedTestDF.collect()) } } @@ -400,7 
+400,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .saveAsTable("t") - assert(table("t").collect().isEmpty) + assert(sqlContext.table("t").collect().isEmpty) } } @@ -412,7 +412,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .partitionBy("p1", "p2") .save(file.getCanonicalPath) - val df = read + val df = sqlContext.read .format(dataSourceName) .option("dataSchema", dataSchema.json) .load(s"${file.getCanonicalPath}/p1=*/p2=???") @@ -452,7 +452,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTempTable("t") { - checkAnswer(table("t"), input.collect()) + checkAnswer(sqlContext.table("t"), input.collect()) } } } @@ -467,7 +467,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { .saveAsTable("t") withTable("t") { - checkAnswer(table("t"), df.select('b, 'c, 'a).collect()) + checkAnswer(sqlContext.table("t"), df.select('b, 'c, 'a).collect()) } } }