diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java index da623d1d15702c4c986f86b8220b86041b250025..7bda219243bf58c19a785b2ff88c1109ac9f5b81 100644 --- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -56,7 +56,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession { } catch (IOException e) { // expected } - instance.write().context(spark.wrapped()).overwrite().save(outputPath); + instance.write().context(spark.sqlContext()).overwrite().save(outputPath); MyParams newInstance = MyParams.load(outputPath); Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); Assert.assertEquals("Params should be preserved.", diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e8e60c64121b5f2c558821f481ffc36e26a45b34..486733a390a0c6c3d79c59e56cb47b8477067b9c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,7 +34,10 @@ __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] class SQLContext(object): - """Wrapper around :class:`SparkSession`, the main entry point to Spark SQL functionality. + """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. + + As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class + here for backward compatibility. A SQLContext can be used create :class:`DataFrame`, register :class:`DataFrame` as tables, execute SQL over tables, cache tables, and read parquet files. diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 257a239c8d7b03c4ea16029f3b857d6b970bc1cf..0e04b88265fa175d874aebb62a507736ec5635ef 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -120,6 +120,8 @@ class SparkSession(object): def appName(self, name): """Sets a name for the application, which will be shown in the Spark web UI. + If no application name is set, a randomly generated name will be used. + :param name: an application name """ return self.config("spark.app.name", name) @@ -133,8 +135,17 @@ class SparkSession(object): @since(2.0) def getOrCreate(self): - """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a new - one based on the options set in this builder. + """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a + new one based on the options set in this builder. + + This method first checks whether there is a valid thread-local SparkSession, + and if yes, return that one. It then checks whether there is a valid global + default SparkSession, and if yes, return that one. If no valid global default + SparkSession exists, the method creates a new SparkSession and assigns the + newly created SparkSession as the global default. + + In case an existing SparkSession is returned, the config options specified + in this builder will be applied to the existing SparkSession. 
""" with self._lock: from pyspark.conf import SparkConf @@ -175,7 +186,7 @@ class SparkSession(object): if jsparkSession is None: jsparkSession = self._jvm.SparkSession(self._jsc.sc()) self._jsparkSession = jsparkSession - self._jwrapped = self._jsparkSession.wrapped() + self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) _monkey_patch_RDD(self) install_exception_handler() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 02dd6547a4adc2cb3523176800296b2d189c8c7e..78a167eef2e4e3218fab08d6b8c2cff05bf06cdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -213,7 +213,7 @@ class Dataset[T] private[sql]( private implicit def classTag = unresolvedTEncoder.clsTag // sqlContext must be val because a stable identifier is expected when you import implicits - @transient lazy val sqlContext: SQLContext = sparkSession.wrapped + @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext protected[sql] def resolve(colName: String): NamedExpression = { queryExecution.analyzed.resolveQuoted(colName, sparkSession.sessionState.analyzer.resolver) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a3e2b49556e4c3d77d3ca3b7e9c42a7f6e3047af..14d12d30bc0b377dcd70f361aa5fa47013008f36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,25 +19,22 @@ package org.apache.spark.sql import java.beans.BeanInfo import java.util.Properties -import java.util.concurrent.atomic.AtomicReference import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD -import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ShowTablesCommand -import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ @@ -46,8 +43,8 @@ import org.apache.spark.sql.util.ExecutionListenerManager /** * The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. * - * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class here - * for backward compatibility. + * As of Spark 2.0, this is replaced by [[SparkSession]]. However, we are keeping the class + * here for backward compatibility. 
* * @groupname basic Basic Operations * @groupname ddl_ops Persistent Catalog DDL @@ -76,42 +73,21 @@ class SQLContext private[sql]( this(sparkSession, true) } + @deprecated("Use SparkSession.builder instead", "2.0.0") def this(sc: SparkContext) = { this(new SparkSession(sc)) } + @deprecated("Use SparkSession.builder instead", "2.0.0") def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) // TODO: move this logic into SparkSession - // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user - // wants to create a new root SQLContext (a SQLContext that is not created by newSession). - private val allowMultipleContexts = - sparkContext.conf.getBoolean( - SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, - SQLConf.ALLOW_MULTIPLE_CONTEXTS.defaultValue.get) - - // Assert no root SQLContext is running when allowMultipleContexts is false. - { - if (!allowMultipleContexts && isRootContext) { - SQLContext.getInstantiatedContextOption() match { - case Some(rootSQLContext) => - val errMsg = "Only one SQLContext/HiveContext may be running in this JVM. " + - s"It is recommended to use SQLContext.getOrCreate to get the instantiated " + - s"SQLContext/HiveContext. To ignore this error, " + - s"set ${SQLConf.ALLOW_MULTIPLE_CONTEXTS.key} = true in SparkConf." - throw new SparkException(errMsg) - case None => // OK - } - } - } - protected[sql] def sessionState: SessionState = sparkSession.sessionState protected[sql] def sharedState: SharedState = sparkSession.sharedState protected[sql] def conf: SQLConf = sessionState.conf protected[sql] def runtimeConf: RuntimeConfig = sparkSession.conf protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager - protected[sql] def listener: SQLListener = sparkSession.listener protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog def sparkContext: SparkContext = sparkSession.sparkContext @@ -123,7 +99,7 @@ class SQLContext private[sql]( * * @since 1.6.0 */ - def newSession(): SQLContext = sparkSession.newSession().wrapped + def newSession(): SQLContext = sparkSession.newSession().sqlContext /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s @@ -760,21 +736,6 @@ class SQLContext private[sql]( schema: StructType): DataFrame = { sparkSession.applySchemaToPythonRDD(rdd, schema) } - - // TODO: move this logic into SparkSession - - // Register a successfully instantiated context to the singleton. This should be at the end of - // the class definition so that the singleton is updated only if there is no exception in the - // construction of the instance. - sparkContext.addSparkListener(new SparkListener { - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - SQLContext.clearInstantiatedContext() - SQLContext.clearSqlListener() - } - }) - - sparkSession.setWrappedContext(self) - SQLContext.setInstantiatedContext(self) } /** @@ -787,19 +748,6 @@ class SQLContext private[sql]( */ object SQLContext { - /** - * The active SQLContext for the current thread. - */ - private val activeContext: InheritableThreadLocal[SQLContext] = - new InheritableThreadLocal[SQLContext] - - /** - * Reference to the created SQLContext. - */ - @transient private val instantiatedContext = new AtomicReference[SQLContext]() - - @transient private val sqlListener = new AtomicReference[SQLListener]() - /** * Get the singleton SQLContext if it exists or create a new one using the given SparkContext. 
* @@ -811,41 +759,9 @@ object SQLContext { * * @since 1.5.0 */ + @deprecated("Use SparkSession.builder instead", "2.0.0") def getOrCreate(sparkContext: SparkContext): SQLContext = { - val ctx = activeContext.get() - if (ctx != null && !ctx.sparkContext.isStopped) { - return ctx - } - - synchronized { - val ctx = instantiatedContext.get() - if (ctx == null || ctx.sparkContext.isStopped) { - new SQLContext(sparkContext) - } else { - ctx - } - } - } - - private[sql] def clearInstantiatedContext(): Unit = { - instantiatedContext.set(null) - } - - private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = { - synchronized { - val ctx = instantiatedContext.get() - if (ctx == null || ctx.sparkContext.isStopped) { - instantiatedContext.set(sqlContext) - } - } - } - - private[sql] def getInstantiatedContextOption(): Option[SQLContext] = { - Option(instantiatedContext.get()) - } - - private[sql] def clearSqlListener(): Unit = { - sqlListener.set(null) + SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext } /** @@ -855,8 +771,9 @@ object SQLContext { * * @since 1.6.0 */ + @deprecated("Use SparkSession.setActiveSession instead", "2.0.0") def setActive(sqlContext: SQLContext): Unit = { - activeContext.set(sqlContext) + SparkSession.setActiveSession(sqlContext.sparkSession) } /** @@ -865,12 +782,9 @@ object SQLContext { * * @since 1.6.0 */ + @deprecated("Use SparkSession.clearActiveSession instead", "2.0.0") def clearActive(): Unit = { - activeContext.remove() - } - - private[sql] def getActive(): Option[SQLContext] = { - Option(activeContext.get()) + SparkSession.clearActiveSession() } /** @@ -894,20 +808,6 @@ object SQLContext { } } - /** - * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. - */ - private[sql] def createListenerAndUI(sc: SparkContext): SQLListener = { - if (sqlListener.get() == null) { - val listener = new SQLListener(sc.conf) - if (sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) - sc.ui.foreach(new SQLTab(listener, _)) - } - } - sqlListener.get() - } - /** * Extract `spark.sql.*` properties from the conf and return them as a [[Properties]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 629243bd1a97b791d643aa575283346d03b1ef2c..f697769bdcdb5f35bca4499e588e213a1e26b9ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.beans.Introspector +import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -30,6 +31,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -98,24 +100,10 @@ class SparkSession private( } /** - * A wrapped version of this session in the form of a [[SQLContext]]. + * A wrapped version of this session in the form of a [[SQLContext]], for backward compatibility. 
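A hedged sketch of what the delegation above means for existing callers of the deprecated singleton API; it assumes the caller can see SQLContext.sparkSession, which the suites later in this patch access from within the org.apache.spark.sql package:

import org.apache.spark.sql.{SparkSession, SQLContext}

val spark = SparkSession.builder().master("local").appName("compat-check").getOrCreate()

// The deprecated accessor is now a thin shim over the builder, so it resolves to the
// SQLContext wrapper of the session above instead of a separately tracked singleton.
val ctx: SQLContext = SQLContext.getOrCreate(spark.sparkContext)
assert(ctx.sparkSession eq spark)

// setActive and clearActive likewise just forward to the SparkSession equivalents.
SQLContext.setActive(ctx)   // same effect as SparkSession.setActiveSession(spark)
SQLContext.clearActive()    // same effect as SparkSession.clearActiveSession()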
*/ @transient - private var _wrapped: SQLContext = _ - - @transient - private val _wrappedLock = new Object - - protected[sql] def wrapped: SQLContext = _wrappedLock.synchronized { - if (_wrapped == null) { - _wrapped = new SQLContext(self, isRootContext = false) - } - _wrapped - } - - protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = _wrappedLock.synchronized { - _wrapped = sqlContext - } + private[sql] val sqlContext: SQLContext = new SQLContext(this) protected[sql] def cacheManager: CacheManager = sharedState.cacheManager protected[sql] def listener: SQLListener = sharedState.listener @@ -238,7 +226,7 @@ class SparkSession private( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SQLContext.setActive(wrapped) + SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) @@ -254,7 +242,7 @@ class SparkSession private( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SQLContext.setActive(wrapped) + SparkSession.setActiveSession(this) val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -573,7 +561,7 @@ class SparkSession private( */ @Experimental object implicits extends SQLImplicits with Serializable { - protected override def _sqlContext: SQLContext = wrapped + protected override def _sqlContext: SQLContext = SparkSession.this.sqlContext } // scalastyle:on @@ -649,8 +637,16 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] var userSuppliedContext: Option[SparkContext] = None + + private[sql] def sparkContext(sparkContext: SparkContext): Builder = synchronized { + userSuppliedContext = Option(sparkContext) + this + } + /** * Sets a name for the application, which will be shown in the Spark web UI. + * If no application name is set, a randomly generated name will be used. * * @since 2.0.0 */ @@ -735,29 +731,130 @@ object SparkSession { } /** - * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one - * based on the options set in this builder. + * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new + * one based on the options set in this builder. + * + * This method first checks whether there is a valid thread-local SparkSession, + * and if yes, return that one. It then checks whether there is a valid global + * default SparkSession, and if yes, return that one. If no valid global default + * SparkSession exists, the method creates a new SparkSession and assigns the + * newly created SparkSession as the global default. + * + * In case an existing SparkSession is returned, the config options specified in + * this builder will be applied to the existing SparkSession. * * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - // Step 1. Create a SparkConf - // Step 2. Get a SparkContext - // Step 3. Get a SparkSession - val sparkConf = new SparkConf() - options.foreach { case (k, v) => sparkConf.set(k, v) } - val sparkContext = SparkContext.getOrCreate(sparkConf) - - SQLContext.getOrCreate(sparkContext).sparkSession + // Get the session from current thread's active session. 
+ var session = activeThreadSession.get() + if ((session ne null) && !session.sparkContext.isStopped) { + options.foreach { case (k, v) => session.conf.set(k, v) } + return session + } + + // Global synchronization so we will only set the default session once. + SparkSession.synchronized { + // If the current thread does not have an active session, get it from the global session. + session = defaultSession.get() + if ((session ne null) && !session.sparkContext.isStopped) { + options.foreach { case (k, v) => session.conf.set(k, v) } + return session + } + + // No active nor global default session. Create a new one. + val sparkContext = userSuppliedContext.getOrElse { + // set app name if not given + if (!options.contains("spark.app.name")) { + options += "spark.app.name" -> java.util.UUID.randomUUID().toString + } + + val sparkConf = new SparkConf() + options.foreach { case (k, v) => sparkConf.set(k, v) } + SparkContext.getOrCreate(sparkConf) + } + session = new SparkSession(sparkContext) + options.foreach { case (k, v) => session.conf.set(k, v) } + defaultSession.set(session) + + // Register a successfully instantiated context to the singleton. This should be at the + // end of the class definition so that the singleton is updated only if there is no + // exception in the construction of the instance. + sparkContext.addSparkListener(new SparkListener { + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { + defaultSession.set(null) + sqlListener.set(null) + } + }) + } + + return session } } /** * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * * @since 2.0.0 */ def builder(): Builder = new Builder + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * a SparkSession with an isolated session, instead of the global (first created) context. + * + * @since 2.0.0 + */ + def setActiveSession(session: SparkSession): Unit = { + activeThreadSession.set(session) + } + + /** + * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will + * return the first created context instead of a thread-local override. + * + * @since 2.0.0 + */ + def clearActiveSession(): Unit = { + activeThreadSession.remove() + } + + /** + * Sets the default SparkSession that is returned by the builder. + * + * @since 2.0.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) + } + + /** + * Clears the default SparkSession that is returned by the builder. + * + * @since 2.0.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) + } + + private[sql] def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + + private[sql] def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + + /** A global SQL listener used for the SQL UI. */ + private[sql] val sqlListener = new AtomicReference[SQLListener]() + + //////////////////////////////////////////////////////////////////////////////////////// + // Private methods from now on + //////////////////////////////////////////////////////////////////////////////////////// + + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. 
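The precedence implemented above — thread-local active session first, then the global default, otherwise a newly created session — is what the new SparkSessionBuilderSuite below exercises; a compact sketch of the observable behaviour:

import org.apache.spark.sql.SparkSession

val global = SparkSession.builder().master("local").getOrCreate()
// With no thread-local override, the builder hands back the global default.
assert(SparkSession.builder().getOrCreate() eq global)

// An isolated session registered for this thread takes precedence over the default.
val isolated = global.newSession()
SparkSession.setActiveSession(isolated)
assert(SparkSession.builder().getOrCreate() eq isolated)

// Clearing the thread-local override falls back to the global default again.
SparkSession.clearActiveSession()
assert(SparkSession.builder().getOrCreate() eq global)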
*/ + private val defaultSession = new AtomicReference[SparkSession] + private val HIVE_SHARED_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSharedState" private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 85af4faf4d090ab46719d14a7ac1b3eee8706d0d..d8911f88b000499276476d5931acce0cf12b88ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -157,7 +157,8 @@ private[sql] case class RowDataSourceScanExec( val outputUnsafeRows = relation match { case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] => - !SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) + !SparkSession.getActiveSession.get.sessionState.conf.getConf( + SQLConf.PARQUET_VECTORIZED_READER_ENABLED) case _: HadoopFsRelation => true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index cb3c46a98bfb40e57a27a5b4e8f7e492985da1d4..34187b9a1ae7f731ed0adacc0b5a995347c19f30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -60,7 +60,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } lazy val analyzed: LogicalPlan = { - SQLContext.setActive(sparkSession.wrapped) + SparkSession.setActiveSession(sparkSession) sparkSession.sessionState.analyzer.execute(logical) } @@ -73,7 +73,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { lazy val optimizedPlan: LogicalPlan = sparkSession.sessionState.optimizer.execute(withCachedData) lazy val sparkPlan: SparkPlan = { - SQLContext.setActive(sparkSession.wrapped) + SparkSession.setActiveSession(sparkSession) planner.plan(ReturnAnswer(optimizedPlan)).next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b94b84d77a50272ac1f8aebe7d08fff16beaf213..045ccc7bd6eae3838fb3a9bfdef60eee3bcfcfd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -27,7 +27,7 @@ import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd.{RDD, RDDOperationScope} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{Row, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -50,7 +50,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * populated by the query planning infrastructure. 
*/ @transient - protected[spark] final val sqlContext = SQLContext.getActive().orNull + final val sqlContext = SparkSession.getActiveSession.map(_.sqlContext).orNull protected def sparkContext = sqlContext.sparkContext @@ -65,7 +65,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { - SQLContext.setActive(sqlContext) + SparkSession.setActiveSession(sqlContext.sparkSession) super.makeCopy(newArgs) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index ccad9b3fd52fd6f475edc51d4e3b58bc9129b3ef..2e17b763a5370ad72e2da4b84ba0868ec2c86d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -178,7 +178,7 @@ case class DataSource( providingClass.newInstance() match { case s: StreamSourceProvider => val (name, schema) = s.sourceSchema( - sparkSession.wrapped, userSpecifiedSchema, className, options) + sparkSession.sqlContext, userSpecifiedSchema, className, options) SourceInfo(name, schema) case format: FileFormat => @@ -198,7 +198,8 @@ case class DataSource( def createSource(metadataPath: String): Source = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sparkSession.wrapped, metadataPath, userSpecifiedSchema, className, options) + s.createSource( + sparkSession.sqlContext, metadataPath, userSpecifiedSchema, className, options) case format: FileFormat => val path = new CaseInsensitiveMap(options).getOrElse("path", { @@ -215,7 +216,7 @@ case class DataSource( /** Returns a sink that can be used to continually write data. */ def createSink(): Sink = { providingClass.newInstance() match { - case s: StreamSinkProvider => s.createSink(sparkSession.wrapped, options, partitionColumns) + case s: StreamSinkProvider => s.createSink(sparkSession.sqlContext, options, partitionColumns) case parquet: parquet.DefaultSource => val caseInsensitiveOptions = new CaseInsensitiveMap(options) @@ -265,9 +266,9 @@ case class DataSource( val relation = (providingClass.newInstance(), userSpecifiedSchema) match { // TODO: Throw when too much is given. case (dataSource: SchemaRelationProvider, Some(schema)) => - dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions, schema) + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema) case (dataSource: RelationProvider, None) => - dataSource.createRelation(sparkSession.wrapped, caseInsensitiveOptions) + dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions) case (_: SchemaRelationProvider, None) => throw new AnalysisException(s"A schema needs to be specified when using $className.") case (_: RelationProvider, Some(_)) => @@ -383,7 +384,7 @@ case class DataSource( providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.wrapped, mode, options, data) + dataSource.createRelation(sparkSession.sqlContext, mode, options, data) case format: FileFormat => // Don't glob path for the write path. The contracts here are: // 1. 
Only one output path can be specified on the write path; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala index 8d332df029166a0a6cb059035a7d1aef99aa401b..88125a2b4da787edb898fc996d30e6fd578db731 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -142,7 +142,7 @@ case class HadoopFsRelation( fileFormat: FileFormat, options: Map[String, String]) extends BaseRelation with FileRelation { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext val schema: StructType = { val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index bcf70fdc4a497fccd8480e6f9761ee8a76a98723..233b7891d664c3875e380309b2143c8aabe94420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -92,7 +92,7 @@ private[sql] case class JDBCRelation( with PrunedFilteredScan with InsertableRelation { - override def sqlContext: SQLContext = sparkSession.wrapped + override def sqlContext: SQLContext = sparkSession.sqlContext override val needConversion: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index df6304d85fe7458efb8f56624359d9a1aa093706..7d09bdcebdc3d0c96cc2109b4f19019dc813cb68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -173,7 +173,7 @@ class StreamExecution( startLatch.countDown() // While active, repeatedly attempt to run batches. 
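External data sources are unaffected by the rename because their contract still hands them a SQLContext; below is a sketch of a relation wired the same way as the ones updated in this patch. The class and its data are hypothetical, and like the test relations further down it assumes a home inside the org.apache.spark.sql packages, since sqlContext on SparkSession is package-private here:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Hypothetical single-column relation: it holds a SparkSession internally but exposes
// the SQLContext wrapper that the BaseRelation contract still requires.
case class SingleColumnScan(values: Seq[Int])(@transient val sparkSession: SparkSession)
  extends BaseRelation with TableScan {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override def schema: StructType =
    StructType(StructField("i", IntegerType, nullable = false) :: Nil)

  override def buildScan(): RDD[Row] =
    sparkSession.sparkContext.parallelize(values.map(Row(_)))
}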
- SQLContext.setActive(sparkSession.wrapped) + SparkSession.setActiveSession(sparkSession) triggerExecutor.execute(() => { if (isActive) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 65bc043076759ceee506b3f86836ca9b8cc3164d..0b490fe71c526cda6398f6f6953f451d562e0e43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1168,7 +1168,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse { + val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { new SparkSqlParser(new SQLConf) } Column(parser.parseExpression(expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d1868980163db894db5fe085d00c2905bb68563..35d67ca2d8c5d80e8d75b35abd886840fce0a0c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -70,16 +70,6 @@ object SQLConf { .intConf .createWithDefault(10) - val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts") - .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + - "When set to false, only one SQLContext/HiveContext is allowed to be created " + - "through the constructor (new SQLContexts/HiveContexts created through newSession " + - "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " + - "a SQLContext/HiveContext has been created, changing the value of this conf will not " + - "have effect.") - .booleanConf - .createWithDefault(true) - val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed") .internal() .doc("When set to true Spark SQL will automatically select a compression codec for each " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index eaf993aaed4d2c725a4e8d20c9c3816e262dc39a..9f6137d6e3c7c3d9a7380622e7828ca8853b8884 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} import org.apache.spark.sql.execution.CacheManager -import org.apache.spark.sql.execution.ui.SQLListener +import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.util.MutableURLClassLoader @@ -38,7 +38,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) { /** * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. */ - val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext) + val listener: SQLListener = createListenerAndUI(sparkContext) /** * A catalog that interacts with external systems. 
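With spark.sql.allowMultipleContexts removed above, isolation now comes from deriving sessions off a single SparkSession rather than from multiple root contexts; a brief illustrative sketch (the configuration value is arbitrary):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()

// Derived sessions share the SparkContext and SharedState (cache, SQL listener,
// external catalog) but keep their own session state, so conf changes stay local.
val sessionA = spark.newSession()
val sessionB = spark.newSession()
sessionA.conf.set("spark.sql.shuffle.partitions", "4")
// sessionB still sees its own value for the same key.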
@@ -51,6 +51,19 @@ private[sql] class SharedState(val sparkContext: SparkContext) { val jarClassLoader = new NonClosableMutableURLClassLoader( org.apache.spark.util.Utils.getContextOrSparkClassLoader) + /** + * Create a SQLListener then add it into SparkContext, and create an SQLTab if there is SparkUI. + */ + private def createListenerAndUI(sc: SparkContext): SQLListener = { + if (SparkSession.sqlListener.get() == null) { + val listener = new SQLListener(sc.conf) + if (SparkSession.sqlListener.compareAndSet(null, listener)) { + sc.addSparkListener(listener) + sc.ui.foreach(new SQLTab(listener, _)) + } + } + SparkSession.sqlListener.get() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 65fe271b69172efc5ece6e7c2a33d690150d3c2d..b447006761f45638b5cb0127dcb12bb96cd4118a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -39,7 +39,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex test("get all tables") { checkAnswer( - spark.wrapped.tables().filter("tableName = 'listtablessuitetable'"), + spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( @@ -48,12 +48,12 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("getting all tables with a database name has no impact on returned table names") { checkAnswer( - spark.wrapped.tables("default").filter("tableName = 'listtablessuitetable'"), + spark.sqlContext.tables("default").filter("tableName = 'listtablessuitetable'"), Row("listtablessuitetable", true)) checkAnswer( @@ -62,7 +62,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex spark.sessionState.catalog.dropTable( TableIdentifier("listtablessuitetable"), ignoreIfNotExists = true) - assert(spark.wrapped.tables().filter("tableName = 'listtablessuitetable'").count() === 0) + assert(spark.sqlContext.tables().filter("tableName = 'listtablessuitetable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -70,7 +70,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(spark.wrapped.tables(), sql("SHOW TABLes")).foreach { + Seq(spark.sqlContext.tables(), sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) @@ -81,7 +81,8 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContex Row(true, "listtablessuitetable") ) checkAnswer( - spark.wrapped.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + spark.sqlContext.tables() + .filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) spark.catalog.dropTempView("tables") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala deleted file mode 100644 index 
0b5a92c256e578e6dadd6ce9566b73b7c863e6bd..0000000000000000000000000000000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark._ -import org.apache.spark.sql.internal.SQLConf - -class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { - - private var originalActiveSQLContext: Option[SQLContext] = _ - private var originalInstantiatedSQLContext: Option[SQLContext] = _ - private var sparkConf: SparkConf = _ - - override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActive() - originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() - - SQLContext.clearActive() - SQLContext.clearInstantiatedContext() - sparkConf = - new SparkConf(false) - .setMaster("local[*]") - .setAppName("test") - .set("spark.ui.enabled", "false") - .set("spark.driver.allowMultipleContexts", "true") - } - - override protected def afterAll(): Unit = { - // Set these states back. - originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) - originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) - } - - def testNewSession(rootSQLContext: SQLContext): Unit = { - // Make sure we can successfully create new Session. - rootSQLContext.newSession() - - // Reset the state. It is always safe to clear the active context. - SQLContext.clearActive() - } - - def testCreatingNewSQLContext(allowsMultipleContexts: Boolean): Unit = { - val conf = - sparkConf - .clone - .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowsMultipleContexts.toString) - val sparkContext = new SparkContext(conf) - - try { - if (allowsMultipleContexts) { - new SQLContext(sparkContext) - SQLContext.clearActive() - } else { - // If allowsMultipleContexts is false, make sure we can get the error. 
- val message = intercept[SparkException] { - new SQLContext(sparkContext) - }.getMessage - assert(message.contains("Only one SQLContext/HiveContext may be running")) - } - } finally { - sparkContext.stop() - } - } - - test("test the flag to disallow creating multiple root SQLContext") { - Seq(false, true).foreach { allowMultipleSQLContexts => - val conf = - sparkConf - .clone - .set(SQLConf.ALLOW_MULTIPLE_CONTEXTS.key, allowMultipleSQLContexts.toString) - val sc = new SparkContext(conf) - try { - val rootSQLContext = new SQLContext(sc) - testNewSession(rootSQLContext) - testNewSession(rootSQLContext) - testCreatingNewSQLContext(allowMultipleSQLContexts) - } finally { - sc.stop() - SQLContext.clearInstantiatedContext() - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 38d7b6e25b829369f04964c73064bc26eabca17f..c9594a7e9ab2865370685c2f9b044e47e2ad798c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -40,7 +40,7 @@ class SQLContextSuite extends SparkFunSuite with SharedSparkContext { val newSession = sqlContext.newSession() assert(SQLContext.getOrCreate(sc).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") - SQLContext.setActive(newSession) + SparkSession.setActiveSession(newSession.sparkSession) assert(SQLContext.getOrCreate(sc).eq(newSession), "SQLContext.getOrCreate after explicitly setActive() did not return the active context") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 743a27aa7a21ddcb36649e8145bfd0067d88370e..460e34a5ff3082890fc195207949d74f38ce2d58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1042,7 +1042,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SET commands semantics using sql()") { - spark.wrapped.conf.clear() + spark.sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -1083,17 +1083,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { sql(s"SET $nonexistentKey"), Row(nonexistentKey, "<undefined>") ) - spark.wrapped.conf.clear() + spark.sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - spark.wrapped.conf.clear() + spark.sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determining // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - spark.wrapped.conf.clear() + spark.sqlContext.conf.clear() } test("apply schema") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index b489b74fec074cfbc37cad9795a4fd36c11e6d63..cd6b2647e0be6c70fc59cd20a7796cbed106657a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -25,6 +25,6 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext { 
test("[SPARK-5235] SQLContext should be serializable") { val spark = SparkSession.builder.getOrCreate() - new JavaSerializer(new SparkConf()).newInstance().serialize(spark.wrapped) + new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ec6a2b3575869f40aaad5ce781716f6303e35088 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkContext, SparkFunSuite} + +/** + * Test cases for the builder pattern of [[SparkSession]]. + */ +class SparkSessionBuilderSuite extends SparkFunSuite { + + private var initialSession: SparkSession = _ + + private lazy val sparkContext: SparkContext = { + initialSession = SparkSession.builder() + .master("local") + .config("spark.ui.enabled", value = false) + .config("some-config", "v2") + .getOrCreate() + initialSession.sparkContext + } + + test("create with config options and propagate them to SparkContext and SparkSession") { + // Creating a new session with config - this works by just calling the lazy val + sparkContext + assert(initialSession.sparkContext.conf.get("some-config") == "v2") + assert(initialSession.conf.get("some-config") == "v2") + SparkSession.clearDefaultSession() + } + + test("use global default session") { + val session = SparkSession.builder().getOrCreate() + assert(SparkSession.builder().getOrCreate() == session) + SparkSession.clearDefaultSession() + } + + test("config options are propagated to existing SparkSession") { + val session1 = SparkSession.builder().config("spark-config1", "a").getOrCreate() + assert(session1.conf.get("spark-config1") == "a") + val session2 = SparkSession.builder().config("spark-config1", "b").getOrCreate() + assert(session1 == session2) + assert(session1.conf.get("spark-config1") == "b") + SparkSession.clearDefaultSession() + } + + test("use session from active thread session and propagate config options") { + val defaultSession = SparkSession.builder().getOrCreate() + val activeSession = defaultSession.newSession() + SparkSession.setActiveSession(activeSession) + val session = SparkSession.builder().config("spark-config2", "a").getOrCreate() + + assert(activeSession != defaultSession) + assert(session == activeSession) + assert(session.conf.get("spark-config2") == "a") + SparkSession.clearActiveSession() + + assert(SparkSession.builder().getOrCreate() == defaultSession) + SparkSession.clearDefaultSession() + } + + test("create a new 
session if the default session has been stopped") { + val defaultSession = SparkSession.builder().getOrCreate() + SparkSession.setDefaultSession(defaultSession) + defaultSession.stop() + val newSession = SparkSession.builder().master("local").getOrCreate() + assert(newSession != defaultSession) + newSession.stop() + } + + test("create a new session if the active thread session has been stopped") { + val activeSession = SparkSession.builder().master("local").getOrCreate() + SparkSession.setActiveSession(activeSession) + activeSession.stop() + val newSession = SparkSession.builder().master("local").getOrCreate() + assert(newSession != activeSession) + newSession.stop() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala index 9523f6f9f5bbf20734dd19fe6ad62656bfe3e39c..4de3cf605caa17bc01ad39112980f7f42b75c279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsSuite.scala @@ -26,9 +26,9 @@ class StatisticsSuite extends QueryTest with SharedSQLContext { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) assert(df.queryExecution.analyzed.statistics.sizeInBytes > - spark.wrapped.conf.autoBroadcastJoinThreshold) + spark.sessionState.conf.autoBroadcastJoinThreshold) assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > - spark.wrapped.conf.autoBroadcastJoinThreshold) + spark.sessionState.conf.autoBroadcastJoinThreshold) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 70a00a43f7db41fb724b4de353f9d1c219b24cb2..2f45db3925a00f9d96a8239d573886bf22ec2388 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -27,21 +27,21 @@ import org.apache.spark.sql.internal.SQLConf class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { - private var originalActiveSQLContext: Option[SQLContext] = _ - private var originalInstantiatedSQLContext: Option[SQLContext] = _ + private var originalActiveSQLContext: Option[SparkSession] = _ + private var originalInstantiatedSQLContext: Option[SparkSession] = _ override protected def beforeAll(): Unit = { - originalActiveSQLContext = SQLContext.getActive() - originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption() + originalActiveSQLContext = SparkSession.getActiveSession + originalInstantiatedSQLContext = SparkSession.getDefaultSession - SQLContext.clearActive() - SQLContext.clearInstantiatedContext() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } override protected def afterAll(): Unit = { // Set these states back. 
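The propagation behaviour covered by the builder suite above boils down to a few lines for callers (sketch; the shuffle-partitions value is arbitrary):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()

// A second getOrCreate() does not construct another session: it returns the existing
// one and applies the requested options to that session's runtime conf.
val same = SparkSession.builder().config("spark.sql.shuffle.partitions", "7").getOrCreate()
assert(same eq spark)
assert(spark.conf.get("spark.sql.shuffle.partitions") == "7")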
- originalActiveSQLContext.foreach(ctx => SQLContext.setActive(ctx)) - originalInstantiatedSQLContext.foreach(ctx => SQLContext.setInstantiatedContext(ctx)) + originalActiveSQLContext.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSQLContext.foreach(ctx => SparkSession.setDefaultSession(ctx)) } private def checkEstimation( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2a5295d0d223103c36a21be0925e5e5b5debd0eb..8243470b19334de0c6d2c8f8e9a0450bed4824d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -155,7 +155,7 @@ class PlannerSuite extends SharedSQLContext { val path = file.getCanonicalPath testData.write.parquet(path) val df = spark.read.parquet(path) - spark.wrapped.registerDataFrameAsTable(df, "testPushed") + spark.sqlContext.registerDataFrameAsTable(df, "testPushed") withTempTable("testPushed") { val exp = sql("select * from testPushed where key = 15").queryExecution.sparkPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index d7eae21f9f556bb2943bc6c01dcda842bb24799a..9fe0e9646e31e6c4a3d0a87164709e985e172a5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -91,7 +91,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { SparkPlanTest - .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.wrapped) match { + .checkAnswer(input, planFunction, expectedAnswer, sortAnswers, spark.sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -115,7 +115,7 @@ private[sql] abstract class SparkPlanTest extends SparkFunSuite { expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { SparkPlanTest.checkAnswer( - input, planFunction, expectedPlanFunction, sortAnswers, spark.wrapped) match { + input, planFunction, expectedPlanFunction, sortAnswers, spark.sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index b5fc51603e1609967af4cdc3a8ac81fdc6a0176c..1753b84ba6af6ca401b7eb131d53318939b7649f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -90,7 +90,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { withParquetDataFrame(data, testVectorized) { df => - spark.wrapped.registerDataFrameAsTable(df, tableName) + spark.sqlContext.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 
4fa1754253aff61d9d3c25437c73f6328154314e..bd197be655d5862337c8305b13e0e615dbd10d97 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( - spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)( + spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( - spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -82,7 +82,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn spark: SparkSession, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = spark.wrapped + implicit val sqlContext = spark.sqlContext makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) } @@ -102,7 +102,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("usage with iterators - only gets and only puts") { withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - implicit val sqlContext = spark.wrapped + implicit val sqlContext = spark.sqlContext val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 @@ -131,7 +131,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( - spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( @@ -150,7 +150,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => - implicit val sqlContext = spark.wrapped + implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") @@ -183,7 +183,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn SparkSession.builder .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) .getOrCreate()) { spark => - implicit val sqlContext = spark.wrapped + implicit val sqlContext = spark.sqlContext val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 1c467137baa86270e085ba50c50821e49d8222c1..2374ffaaa5036b605ad31947943613c8326a78fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Mockito.mock import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -400,8 +400,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly val sc = new SparkContext(conf) try { - SQLContext.clearSqlListener() - val spark = new SQLContext(sc) + SparkSession.sqlListener.set(null) + val spark = new SparkSession(sc) import spark.implicits._ // Run 100 successful executions and 100 failed executions. // Each execution only has one job and one stage. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 81bc973be74a3ed9edcf6913bd637e51f312a95c..0296229100a24871a7fc4d46bab970da00987eb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -35,7 +35,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { // Set a conf first. spark.conf.set(testKey, testVal) // Clear the conf. - spark.wrapped.conf.clear() + spark.sqlContext.conf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. 
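The two replacements used in these suites — spark.sqlContext.conf.clear() and spark.sessionState.conf.clear() — are interchangeable, because SQLContext.conf just delegates to the session's SQLConf (see the SQLContext hunk near the top of this patch). A short sketch, assuming the package-private access that the suites have:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local").getOrCreate()

spark.conf.set("spark.sql.shuffle.partitions", "3")
// Both handles point at the same per-session SQLConf, so either clear() resets it.
assert(spark.sqlContext.conf eq spark.sessionState.conf)
spark.sqlContext.conf.clear()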
     assert(spark.conf.getAll === TestSQLContext.overrideConfs)
@@ -50,11 +50,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
     assert(spark.conf.get(testKey, testVal + "_") === testVal)
     assert(spark.conf.getAll.contains(testKey))

-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
   }

   test("parse SQL set commands") {
-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
     sql(s"set $testKey=$testVal")
     assert(spark.conf.get(testKey, testVal + "_") === testVal)
     assert(spark.conf.get(testKey, testVal + "_") === testVal)
@@ -72,11 +72,11 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
     sql(s"set $key=")
     assert(spark.conf.get(key, "0") === "")

-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
   }

   test("set command for display") {
-    spark.wrapped.conf.clear()
+    spark.sessionState.conf.clear()
     checkAnswer(
       sql("SET").where("key = 'spark.sql.groupByOrdinal'").select("key", "value"),
       Nil)
@@ -97,7 +97,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
   }

   test("deprecated property") {
-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
     val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS)
     try{
       sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
@@ -108,7 +108,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
   }

   test("invalid conf value") {
-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
     val e = intercept[IllegalArgumentException] {
       sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
     }
@@ -116,7 +116,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {

   test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") {
-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()

     spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100")
     assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100)
@@ -144,7 +144,7 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
       spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g")
     }

-    spark.wrapped.conf.clear()
+    spark.sqlContext.conf.clear()
   }

   test("SparkSession can access configs set in SparkConf") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 612cfc7ec7bd6d3409b8d9aa39bed861792bffd3..a34f70ed65b5c0aa01e28158e4ff4ef11c60233e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -41,7 +41,7 @@ case class SimpleDDLScan(
     table: String)(@transient val sparkSession: SparkSession)
   extends BaseRelation with TableScan {

-  override def sqlContext: SQLContext = sparkSession.wrapped
+  override def sqlContext: SQLContext = sparkSession.sqlContext

   override def schema: StructType = StructType(Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 51d04f2f4efc644a5fc91f60029c7fbe84ed3dd2..f969660ddd322ae28ce4c250768c9ab5e288f0b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -40,7 +40,7 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sparkSession: S
   extends BaseRelation with PrunedFilteredScan {

-  override def sqlContext: SQLContext = sparkSession.wrapped
+  override def sqlContext: SQLContext = sparkSession.sqlContext

   override def schema: StructType = StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index cd0256db43aad3885e1b79587e543438b34ae92d..9cdf7dea7663e2403a3663fc43dfa13a23664e14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -37,7 +37,7 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sparkSession: Spa
   extends BaseRelation with PrunedScan {

-  override def sqlContext: SQLContext = sparkSession.wrapped
+  override def sqlContext: SQLContext = sparkSession.sqlContext

   override def schema: StructType = StructType(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 34b8726a922fbcf1bd3b3543660ea033dbb62b7f..cddf4a1884fa8063935e4e35146c28f00bd5ab58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -38,7 +38,7 @@ class SimpleScanSource extends RelationProvider {
 case class SimpleScan(from: Int, to: Int)(@transient val sparkSession: SparkSession)
   extends BaseRelation with TableScan {

-  override def sqlContext: SQLContext = sparkSession.wrapped
+  override def sqlContext: SQLContext = sparkSession.sqlContext

   override def schema: StructType =
     StructType(StructField("i", IntegerType, nullable = false) :: Nil)
@@ -70,7 +70,7 @@ case class AllDataTypesScan(
   extends BaseRelation with TableScan {

-  override def sqlContext: SQLContext = sparkSession.wrapped
+  override def sqlContext: SQLContext = sparkSession.sqlContext

   override def schema: StructType = userSpecifiedSchema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index ff535055493306946f716414891562d9226c08ca..e6c0ce95e7b5729ab336177a9434a2cf61e6a843 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
     q.stop()

     verify(LastOptions.mockStreamSourceProvider).createSource(
-      spark.wrapped,
+      spark.sqlContext,
       checkpointLocation + "/sources/0",
       None,
       "org.apache.spark.sql.streaming.test",
       Map.empty)

     verify(LastOptions.mockStreamSourceProvider).createSource(
-      spark.wrapped,
+      spark.sqlContext,
       checkpointLocation + "/sources/1",
       None,
       "org.apache.spark.sql.streaming.test",
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 421f6bca7f865b352492de459c058edcbb4e29ba..0cfe260e52152f609b0de2952e09ff1c88d2aa1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -30,7 +30,7 @@ private[sql] trait SQLTestData { self =>
   // Helper object to import SQL implicits without a concrete SQLContext
   private object internalImplicits extends SQLImplicits {
-    protected override def _sqlContext: SQLContext = self.spark.wrapped
+    protected override def _sqlContext: SQLContext = self.spark.sqlContext
   }

   import internalImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 51538eca644f6a92f0dab417688c0821e8258974..853dd0ff3f60103159dd372131b6a65473e17bab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -66,7 +66,7 @@ private[sql] trait SQLTestUtils
    * but the implicits import is needed in the constructor.
    */
   protected object testImplicits extends SQLImplicits {
-    protected override def _sqlContext: SQLContext = self.spark.wrapped
+    protected override def _sqlContext: SQLContext = self.spark.sqlContext
   }

   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index 620bfa995aa20eee34120cea0fd941ec8a12220e..79c37faa4e9ba9d21a8cfa516f1ac65f6a108ea3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -44,13 +44,13 @@ trait SharedSQLContext extends SQLTestUtils {
   /**
    * The [[TestSQLContext]] to use for all tests in this suite.
    */
-  protected implicit def sqlContext: SQLContext = _spark.wrapped
+  protected implicit def sqlContext: SQLContext = _spark.sqlContext

   /**
    * Initialize the [[TestSparkSession]].
    */
   protected override def beforeAll(): Unit = {
-    SQLContext.clearSqlListener()
+    SparkSession.sqlListener.set(null)
     if (_spark == null) {
       _spark = new TestSparkSession(sparkConf)
     }
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
index 8de223f444f70f2519d65fe0907e87b4428248a3..638911599aad3b28c6359286e4c1fc2bebee11f0 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala
@@ -56,7 +56,7 @@ private[hive] object SparkSQLEnv extends Logging {
       val sparkSession = SparkSession.builder.config(sparkConf).enableHiveSupport().getOrCreate()
       sparkContext = sparkSession.sparkContext
-      sqlContext = sparkSession.wrapped
+      sqlContext = sparkSession.sqlContext

       val sessionState = sparkSession.sessionState.asInstanceOf[HiveSessionState]
       sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index d2cb62c617d4413f1bbf5cfe9f9a5625fe9b940a..7c74a0308d4833a4857ed874b13a758c04672d99 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -30,7 +30,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
   override protected def beforeEach(): Unit = {
     super.beforeEach()
-    if (spark.wrapped.tableNames().contains("src")) {
+    if (spark.sqlContext.tableNames().contains("src")) {
       spark.catalog.dropTempView("src")
     }
     Seq((1, "")).toDF("key", "value").createOrReplaceTempView("src")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 6c9ce208dbd6ae46de51165609daeaee7ca817af..622b043581c5d606ae408e08d44ea6c868b4ce26 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -36,11 +36,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     withTempDatabase { db =>
       activateDatabase(db) {
         df.write.mode(SaveMode.Overwrite).saveAsTable("t")
-        assert(spark.wrapped.tableNames().contains("t"))
+        assert(spark.sqlContext.tableNames().contains("t"))
         checkAnswer(spark.table("t"), df)
       }

-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))
       checkAnswer(spark.table(s"$db.t"), df)

       checkTablePath(db, "t")
@@ -50,7 +50,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
   test(s"saveAsTable() to non-default database - without USE - Overwrite") {
     withTempDatabase { db =>
       df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))
       checkAnswer(spark.table(s"$db.t"), df)

       checkTablePath(db, "t")
@@ -65,7 +65,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
         df.write.format("parquet").mode(SaveMode.Overwrite).save(path)

         spark.catalog.createExternalTable("t", path, "parquet")
-        assert(spark.wrapped.tableNames(db).contains("t"))
+        assert(spark.sqlContext.tableNames(db).contains("t"))
         checkAnswer(spark.table("t"), df)

         sql(
@@ -76,7 +76,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
             | path '$path'
             |)
           """.stripMargin)
-        assert(spark.wrapped.tableNames(db).contains("t1"))
+        assert(spark.sqlContext.tableNames(db).contains("t1"))
         checkAnswer(spark.table("t1"), df)
       }
     }
@@ -90,7 +90,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
         df.write.format("parquet").mode(SaveMode.Overwrite).save(path)

         spark.catalog.createExternalTable(s"$db.t", path, "parquet")
-        assert(spark.wrapped.tableNames(db).contains("t"))
+        assert(spark.sqlContext.tableNames(db).contains("t"))
         checkAnswer(spark.table(s"$db.t"), df)

         sql(
@@ -101,7 +101,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
             | path '$path'
             |)
           """.stripMargin)
-        assert(spark.wrapped.tableNames(db).contains("t1"))
+        assert(spark.sqlContext.tableNames(db).contains("t1"))
         checkAnswer(spark.table(s"$db.t1"), df)
       }
     }
@@ -112,11 +112,11 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
       activateDatabase(db) {
         df.write.mode(SaveMode.Overwrite).saveAsTable("t")
         df.write.mode(SaveMode.Append).saveAsTable("t")
-        assert(spark.wrapped.tableNames().contains("t"))
+        assert(spark.sqlContext.tableNames().contains("t"))
         checkAnswer(spark.table("t"), df.union(df))
       }

-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))
       checkAnswer(spark.table(s"$db.t"), df.union(df))

       checkTablePath(db, "t")
@@ -127,7 +127,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     withTempDatabase { db =>
       df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t")
       df.write.mode(SaveMode.Append).saveAsTable(s"$db.t")
-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))
       checkAnswer(spark.table(s"$db.t"), df.union(df))
       checkTablePath(db, "t")
@@ -138,7 +138,7 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     withTempDatabase { db =>
       activateDatabase(db) {
         df.write.mode(SaveMode.Overwrite).saveAsTable("t")
-        assert(spark.wrapped.tableNames().contains("t"))
+        assert(spark.sqlContext.tableNames().contains("t"))
         df.write.insertInto(s"$db.t")
         checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -150,10 +150,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     withTempDatabase { db =>
       activateDatabase(db) {
         df.write.mode(SaveMode.Overwrite).saveAsTable("t")
-        assert(spark.wrapped.tableNames().contains("t"))
+        assert(spark.sqlContext.tableNames().contains("t"))
       }

-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))

       df.write.insertInto(s"$db.t")
       checkAnswer(spark.table(s"$db.t"), df.union(df))
@@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     withTempDatabase { db =>
      activateDatabase(db) {
        sql(s"CREATE TABLE t (key INT)")
-        assert(spark.wrapped.tableNames().contains("t"))
-        assert(!spark.wrapped.tableNames("default").contains("t"))
+        assert(spark.sqlContext.tableNames().contains("t"))
+        assert(!spark.sqlContext.tableNames("default").contains("t"))
       }

-      assert(!spark.wrapped.tableNames().contains("t"))
-      assert(spark.wrapped.tableNames(db).contains("t"))
+      assert(!spark.sqlContext.tableNames().contains("t"))
+      assert(spark.sqlContext.tableNames(db).contains("t"))

       activateDatabase(db) {
         sql(s"DROP TABLE t")
-        assert(!spark.wrapped.tableNames().contains("t"))
-        assert(!spark.wrapped.tableNames("default").contains("t"))
+        assert(!spark.sqlContext.tableNames().contains("t"))
+        assert(!spark.sqlContext.tableNames("default").contains("t"))
       }

-      assert(!spark.wrapped.tableNames().contains("t"))
-      assert(!spark.wrapped.tableNames(db).contains("t"))
+      assert(!spark.sqlContext.tableNames().contains("t"))
+      assert(!spark.sqlContext.tableNames(db).contains("t"))
     }
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 81f3ea8a6e80190aeabfc31153c2211e24db5788..8a31a49d97f085e7f4e45aa819dd90128ef462ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1417,7 +1417,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       """.stripMargin)

     checkAnswer(
-      spark.wrapped.tables().select('isTemporary).filter('tableName === "t2"),
+      spark.sqlContext.tables().select('isTemporary).filter('tableName === "t2"),
       Row(true)
     )
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index aba60da33fe3e9c38d735c4f2af3705018c1b995..bb351e20c5e99d4d991e4f458549213f76538854 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -61,7 +61,7 @@ private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton {
       (data: Seq[T], tableName: String)
       (f: => Unit): Unit = {
     withOrcDataFrame(data) { df =>
-      spark.wrapped.registerDataFrameAsTable(df, tableName)
+      spark.sqlContext.registerDataFrameAsTable(df, tableName)
       withTempTable(tableName)(f)
     }
   }