Commit 0c6415ae authored by Xiangrui Meng

[SPARK-17822][R] Make JVMObjectTracker a member variable of RBackend


## What changes were proposed in this pull request?

* This PR changes `JVMObjectTracker` from an `object` to a `class` and associates an instance with each `RBackend`, so that we can manage the lifecycle of JVM objects when there are multiple `RBackend` sessions (see the sketch below). `RBackend.close` now clears the object tracker explicitly.
* I assume that `SQLUtils` and `RRunner` do not need to track JVM instances, though this assumption could be wrong.
* Small refactor of `SerDe.sqlSerDe` to increase readability.
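
A minimal sketch of the new per-session lifecycle (it mirrors the `RBackendSuite` test added below and assumes code running inside the `org.apache.spark.api.r` package, since `jvmObjectTracker` is `private[r]`):

```scala
val backend = new RBackend
val tracker = backend.jvmObjectTracker    // per-backend instance, no longer a global singleton
val id = tracker.addAndGetId(new Object)  // objects handed to R get session-scoped IDs
backend.close()                           // close() clears only this session's tracker
assert(tracker.get(id).isEmpty)
```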

## How was this patch tested?

* Added unit tests for `JVMObjectTracker`.
* Wait for Jenkins to run full tests.

Author: Xiangrui Meng <meng@databricks.com>

Closes #16154 from mengxr/SPARK-17822.

(cherry picked from commit fd48d80a)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
parent b226f10e

core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala (new file):

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.r

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.ConcurrentHashMap

/** JVM object ID wrapper */
private[r] case class JVMObjectId(id: String) {
  require(id != null, "Object ID cannot be null.")
}

/**
 * Tracks JVM objects returned to R.
 * This is useful for referencing these objects in RPC calls.
 */
private[r] class JVMObjectTracker {

  private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]()
  private[this] val objCounter = new AtomicInteger()

  /**
   * Returns the JVM object associated with the input key or None if not found.
   */
  final def get(id: JVMObjectId): Option[Object] = this.synchronized {
    if (objMap.containsKey(id)) {
      Some(objMap.get(id))
    } else {
      None
    }
  }

  /**
   * Returns the JVM object associated with the input key or throws an exception if not found.
   */
  @throws[NoSuchElementException]("if key does not exist.")
  final def apply(id: JVMObjectId): Object = {
    get(id).getOrElse(
      throw new NoSuchElementException(s"$id does not exist.")
    )
  }

  /**
   * Adds a JVM object to track and returns an assigned ID, which is unique within this tracker.
   */
  final def addAndGetId(obj: Object): JVMObjectId = {
    val id = JVMObjectId(objCounter.getAndIncrement().toString)
    objMap.put(id, obj)
    id
  }

  /**
   * Removes and returns a JVM object with the specific ID from the tracker, or None if not found.
   */
  final def remove(id: JVMObjectId): Option[Object] = this.synchronized {
    if (objMap.containsKey(id)) {
      Some(objMap.remove(id))
    } else {
      None
    }
  }

  /**
   * Number of JVM objects being tracked.
   */
  final def size: Int = objMap.size()

  /**
   * Clears the tracker.
   */
  final def clear(): Unit = objMap.clear()
}
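
For reference, a minimal usage sketch of the tracker API defined above (an illustrative REPL-style flow, not part of the commit):

```scala
val tracker = new JVMObjectTracker
val id = tracker.addAndGetId(new Object)  // unique within this tracker, e.g. JVMObjectId("0")
val obj = tracker(id)                     // apply: throws NoSuchElementException if the ID is absent
val opt = tracker.get(id)                 // Option-returning lookup: Some(obj) here
tracker.remove(id)                        // returns Some(obj) and stops tracking it
tracker.clear()                           // RBackend.close() calls this to end the session
```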

core/src/main/scala/org/apache/spark/api/r/RBackend.scala:

@@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
 import java.util.concurrent.TimeUnit
 
 import io.netty.bootstrap.ServerBootstrap
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
 import io.netty.channel.nio.NioEventLoopGroup
 import io.netty.channel.socket.SocketChannel
 import io.netty.channel.socket.nio.NioServerSocketChannel
@@ -42,6 +42,9 @@ private[spark] class RBackend {
   private[this] var bootstrap: ServerBootstrap = null
   private[this] var bossGroup: EventLoopGroup = null
 
+  /** Tracks JVM objects returned to R for this RBackend instance. */
+  private[r] val jvmObjectTracker = new JVMObjectTracker
+
   def init(): Int = {
     val conf = new SparkConf()
     val backendConnectionTimeout = conf.getInt(
@@ -94,6 +97,7 @@ private[spark] class RBackend {
       bootstrap.childGroup().shutdownGracefully()
     }
     bootstrap = null
+    jvmObjectTracker.clear()
   }
 }

core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala:

@@ -20,7 +20,6 @@ package org.apache.spark.api.r
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 import java.util.concurrent.TimeUnit
 
-import scala.collection.mutable.HashMap
 import scala.language.existentials
 
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
@@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend)
           assert(numArgs == 1)
           writeInt(dos, 0)
-          writeObject(dos, args(0))
+          writeObject(dos, args(0), server.jvmObjectTracker)
         case "stopBackend" =>
           writeInt(dos, 0)
           writeType(dos, "void")
@@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend)
             val t = readObjectType(dis)
             assert(t == 'c')
             val objToRemove = readString(dis)
-            JVMObjectTracker.remove(objToRemove)
+            server.jvmObjectTracker.remove(JVMObjectId(objToRemove))
             writeInt(dos, 0)
-            writeObject(dos, null)
+            writeObject(dos, null, server.jvmObjectTracker)
           } catch {
             case e: Exception =>
               logError(s"Removing $objId failed", e)
@@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend)
     val cls = if (isStatic) {
       Utils.classForName(objId)
     } else {
-      JVMObjectTracker.get(objId) match {
-        case None => throw new IllegalArgumentException("Object not found " + objId)
-        case Some(o) =>
-          obj = o
-          o.getClass
-      }
+      obj = server.jvmObjectTracker(JVMObjectId(objId))
+      obj.getClass
     }
 
     val args = readArgs(numArgs, dis)
@@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend)
         // Write status bit
         writeInt(dos, 0)
-        writeObject(dos, ret.asInstanceOf[AnyRef])
+        writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker)
       } else if (methodName == "<init>") {
         // methodName should be "<init>" for constructor
         val ctors = cls.getConstructors
@@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend)
         val obj = ctors(index.get).newInstance(args : _*)
 
         writeInt(dos, 0)
-        writeObject(dos, obj.asInstanceOf[AnyRef])
+        writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker)
       } else {
         throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId)
       }
@@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend)
   // Read a number of arguments from the data input stream
   def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
     (0 until numArgs).map { _ =>
-      readObject(dis)
+      readObject(dis, server.jvmObjectTracker)
     }.toArray
   }
@@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend)
     }
   }
 }
-
-/**
- * Helper singleton that tracks JVM objects returned to R.
- * This is useful for referencing these objects in RPC calls.
- */
-private[r] object JVMObjectTracker {
-
-  // TODO: This map should be thread-safe if we want to support multiple
-  // connections at the same time
-  private[this] val objMap = new HashMap[String, Object]
-
-  // TODO: We support only one connection now, so an integer is fine.
-  //       Investigate using use atomic integer in the future.
-  private[this] var objCounter: Int = 0
-
-  def getObject(id: String): Object = {
-    objMap(id)
-  }
-
-  def get(id: String): Option[Object] = {
-    objMap.get(id)
-  }
-
-  def put(obj: Object): String = {
-    val objId = objCounter.toString
-    objCounter = objCounter + 1
-    objMap.put(objId, obj)
-    objId
-  }
-
-  def remove(id: String): Option[Object] = {
-    objMap.remove(id)
-  }
-}

core/src/main/scala/org/apache/spark/api/r/RRunner.scala:

@@ -152,7 +152,7 @@ private[spark] class RRunner[U](
           dataOut.writeInt(mode)
 
           if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames)
+            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
           }
 
           if (!iter.hasNext) {

core/src/main/scala/org/apache/spark/api/r/SerDe.scala:

@@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray
  * Utility functions to serialize, deserialize objects to / from R
  */
 private[spark] object SerDe {
-  type ReadObject = (DataInputStream, Char) => Object
-  type WriteObject = (DataOutputStream, Object) => Boolean
+  type SQLReadObject = (DataInputStream, Char) => Object
+  type SQLWriteObject = (DataOutputStream, Object) => Boolean
 
-  var sqlSerDe: (ReadObject, WriteObject) = _
+  private[this] var sqlReadObject: SQLReadObject = _
+  private[this] var sqlWriteObject: SQLWriteObject = _
 
-  def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = {
-    this.sqlSerDe = sqlSerDe
+  def setSQLReadObject(value: SQLReadObject): this.type = {
+    sqlReadObject = value
+    this
+  }
+
+  def setSQLWriteObject(value: SQLWriteObject): this.type = {
+    sqlWriteObject = value
+    this
   }
 
   // Type mapping from R to Java
@@ -56,32 +63,33 @@ private[spark] object SerDe {
     dis.readByte().toChar
   }
 
-  def readObject(dis: DataInputStream): Object = {
+  def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = {
     val dataType = readObjectType(dis)
-    readTypedObject(dis, dataType)
+    readTypedObject(dis, dataType, jvmObjectTracker)
   }
 
   def readTypedObject(
       dis: DataInputStream,
-      dataType: Char): Object = {
+      dataType: Char,
+      jvmObjectTracker: JVMObjectTracker): Object = {
     dataType match {
       case 'n' => null
       case 'i' => new java.lang.Integer(readInt(dis))
       case 'd' => new java.lang.Double(readDouble(dis))
       case 'b' => new java.lang.Boolean(readBoolean(dis))
       case 'c' => readString(dis)
-      case 'e' => readMap(dis)
+      case 'e' => readMap(dis, jvmObjectTracker)
       case 'r' => readBytes(dis)
-      case 'a' => readArray(dis)
-      case 'l' => readList(dis)
+      case 'a' => readArray(dis, jvmObjectTracker)
+      case 'l' => readList(dis, jvmObjectTracker)
       case 'D' => readDate(dis)
       case 't' => readTime(dis)
-      case 'j' => JVMObjectTracker.getObject(readString(dis))
+      case 'j' => jvmObjectTracker(JVMObjectId(readString(dis)))
       case _ =>
-        if (sqlSerDe == null || sqlSerDe._1 == null) {
+        if (sqlReadObject == null) {
          throw new IllegalArgumentException (s"Invalid type $dataType")
        } else {
-          val obj = (sqlSerDe._1)(dis, dataType)
+          val obj = sqlReadObject(dis, dataType)
          if (obj == null) {
            throw new IllegalArgumentException (s"Invalid type $dataType")
          } else {
@@ -181,28 +189,28 @@ private[spark] object SerDe {
   }
 
   // All elements of an array must be of the same type
-  def readArray(dis: DataInputStream): Array[_] = {
+  def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = {
     val arrType = readObjectType(dis)
     arrType match {
       case 'i' => readIntArr(dis)
       case 'c' => readStringArr(dis)
       case 'd' => readDoubleArr(dis)
       case 'b' => readBooleanArr(dis)
-      case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
+      case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x)))
       case 'r' => readBytesArr(dis)
       case 'a' =>
         val len = readInt(dis)
-        (0 until len).map(_ => readArray(dis)).toArray
+        (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray
       case 'l' =>
         val len = readInt(dis)
-        (0 until len).map(_ => readList(dis)).toArray
+        (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray
       case _ =>
-        if (sqlSerDe == null || sqlSerDe._1 == null) {
+        if (sqlReadObject == null) {
          throw new IllegalArgumentException (s"Invalid array type $arrType")
        } else {
          val len = readInt(dis)
          (0 until len).map { _ =>
-            val obj = (sqlSerDe._1)(dis, arrType)
+            val obj = sqlReadObject(dis, arrType)
            if (obj == null) {
              throw new IllegalArgumentException (s"Invalid array type $arrType")
            } else {
@@ -215,17 +223,19 @@ private[spark] object SerDe {
 
   // Each element of a list can be of different type. They are all represented
   // as Object on JVM side
-  def readList(dis: DataInputStream): Array[Object] = {
+  def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = {
     val len = readInt(dis)
-    (0 until len).map(_ => readObject(dis)).toArray
+    (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray
   }
 
-  def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+  def readMap(
+      in: DataInputStream,
+      jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = {
     val len = readInt(in)
     if (len > 0) {
       // Keys is an array of String
-      val keys = readArray(in).asInstanceOf[Array[Object]]
-      val values = readList(in)
+      val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]]
+      val values = readList(in, jvmObjectTracker)
 
       keys.zip(values).toMap.asJava
     } else {
@@ -272,7 +282,11 @@ private[spark] object SerDe {
     }
   }
 
-  private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = {
+  private def writeKeyValue(
+      dos: DataOutputStream,
+      key: Object,
+      value: Object,
+      jvmObjectTracker: JVMObjectTracker): Unit = {
     if (key == null) {
       throw new IllegalArgumentException("Key in map can't be null.")
     } else if (!key.isInstanceOf[String]) {
@@ -280,10 +294,10 @@ private[spark] object SerDe {
     }
 
     writeString(dos, key.asInstanceOf[String])
-    writeObject(dos, value)
+    writeObject(dos, value, jvmObjectTracker)
   }
 
-  def writeObject(dos: DataOutputStream, obj: Object): Unit = {
+  def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = {
     if (obj == null) {
       writeType(dos, "void")
     } else {
@@ -373,14 +387,14 @@ private[spark] object SerDe {
         case v: Array[Object] =>
           writeType(dos, "list")
          writeInt(dos, v.length)
-          v.foreach(elem => writeObject(dos, elem))
+          v.foreach(elem => writeObject(dos, elem, jvmObjectTracker))
 
        // Handle Properties
        // This must be above the case java.util.Map below.
        // (Properties implements Map<Object,Object> and will be serialized as map otherwise)
        case v: java.util.Properties =>
          writeType(dos, "jobj")
-          writeJObj(dos, value)
+          writeJObj(dos, value, jvmObjectTracker)
 
        // Handle map
        case v: java.util.Map[_, _] =>
@@ -392,19 +406,21 @@ private[spark] object SerDe {
            val key = entry.getKey
            val value = entry.getValue
-            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+            writeKeyValue(
+              dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker)
          }
 
        case v: scala.collection.Map[_, _] =>
          writeType(dos, "map")
          writeInt(dos, v.size)
-          v.foreach { case (key, value) =>
-            writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object])
+          v.foreach { case (k1, v1) =>
+            writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker)
          }
 
        case _ =>
-          if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) {
+          val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value)
+          if (!sqlWriteSucceeded) {
            writeType(dos, "jobj")
-            writeJObj(dos, value)
+            writeJObj(dos, value, jvmObjectTracker)
          }
      }
    }
@@ -447,9 +463,9 @@ private[spark] object SerDe {
     out.write(value)
   }
 
-  def writeJObj(out: DataOutputStream, value: Object): Unit = {
-    val objId = JVMObjectTracker.put(value)
-    writeString(out, objId)
+  def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = {
+    val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value)
+    writeString(out, id)
   }
 
   def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
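
The `sqlSerDe` tuple is replaced by two chainable setters. A hedged sketch of the new registration pattern (the handler names and bodies below are placeholders; the real handlers are `readSqlObject` and `writeSqlObject` in `SQLUtils`):

```scala
import java.io.{DataInputStream, DataOutputStream}

// Placeholder handlers, illustrative only.
def readCustomObject(dis: DataInputStream, dataType: Char): Object = {
  null  // returning null tells SerDe the type byte was not handled
}

def writeCustomObject(dos: DataOutputStream, obj: Object): Boolean = {
  false // returning false makes SerDe fall back to a generic "jobj" reference
}

// Each setter returns `this.type`, so both registrations chain in one expression:
SerDe.setSQLReadObject(readCustomObject).setSQLWriteObject(writeCustomObject)
```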

core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala (new file):

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.r

import org.apache.spark.SparkFunSuite

class JVMObjectTrackerSuite extends SparkFunSuite {

  test("JVMObjectId does not take null IDs") {
    intercept[IllegalArgumentException] {
      JVMObjectId(null)
    }
  }

  test("JVMObjectTracker") {
    val tracker = new JVMObjectTracker
    assert(tracker.size === 0)
    withClue("an empty tracker can be cleared") {
      tracker.clear()
    }
    val none = JVMObjectId("none")
    assert(tracker.get(none) === None)
    intercept[NoSuchElementException] {
      tracker(JVMObjectId("none"))
    }
    val obj1 = new Object
    val id1 = tracker.addAndGetId(obj1)
    assert(id1 != null)
    assert(tracker.size === 1)
    assert(tracker.get(id1).get.eq(obj1))
    assert(tracker(id1).eq(obj1))
    val obj2 = new Object
    val id2 = tracker.addAndGetId(obj2)
    assert(id1 !== id2)
    assert(tracker.size === 2)
    assert(tracker(id2).eq(obj2))
    val Some(obj1Removed) = tracker.remove(id1)
    assert(obj1Removed.eq(obj1))
    assert(tracker.get(id1) === None)
    assert(tracker.size === 1)
    assert(tracker(id2).eq(obj2))
    val obj3 = new Object
    val id3 = tracker.addAndGetId(obj3)
    assert(tracker.size === 2)
    assert(id3 != id1)
    assert(id3 != id2)
    assert(tracker(id3).eq(obj3))
    tracker.clear()
    assert(tracker.size === 0)
    assert(tracker.get(id1) === None)
    assert(tracker.get(id2) === None)
    assert(tracker.get(id3) === None)
  }
}

core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala (new file):

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.r

import org.apache.spark.SparkFunSuite

class RBackendSuite extends SparkFunSuite {

  test("close() clears jvmObjectTracker") {
    val backend = new RBackend
    val tracker = backend.jvmObjectTracker
    val id = tracker.addAndGetId(new Object)
    backend.close()
    assert(tracker.get(id) === None)
    assert(tracker.size === 0)
  }
}

sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala:

@@ -36,7 +36,7 @@ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.types._
 
 private[sql] object SQLUtils extends Logging {
-  SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
+  SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject)
 
   private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = {
     sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
@@ -158,7 +158,7 @@ private[sql] object SQLUtils extends Logging {
     val dis = new DataInputStream(bis)
     val num = SerDe.readInt(dis)
     Row.fromSeq((0 until num).map { i =>
-      doConversion(SerDe.readObject(dis), schema.fields(i).dataType)
+      doConversion(SerDe.readObject(dis, jvmObjectTracker = null), schema.fields(i).dataType)
     })
   }
@@ -167,7 +167,7 @@ private[sql] object SQLUtils extends Logging {
     val dos = new DataOutputStream(bos)
     val cols = (0 until row.length).map(row(_).asInstanceOf[Object]).toArray
-    SerDe.writeObject(dos, cols)
+    SerDe.writeObject(dos, cols, jvmObjectTracker = null)
     bos.toByteArray()
   }
@@ -247,7 +247,7 @@ private[sql] object SQLUtils extends Logging {
     dataType match {
       case 's' =>
         // Read StructType for DataFrame
-        val fields = SerDe.readList(dis).asInstanceOf[Array[Object]]
+        val fields = SerDe.readList(dis, jvmObjectTracker = null).asInstanceOf[Array[Object]]
         Row.fromSeq(fields)
       case _ => null
     }
@@ -258,8 +258,8 @@ private[sql] object SQLUtils extends Logging {
       // Handle struct type in DataFrame
       case v: GenericRowWithSchema =>
         dos.writeByte('s')
-        SerDe.writeObject(dos, v.schema.fieldNames)
-        SerDe.writeObject(dos, v.values)
+        SerDe.writeObject(dos, v.schema.fieldNames, jvmObjectTracker = null)
+        SerDe.writeObject(dos, v.values, jvmObjectTracker = null)
         true
       case _ =>
         false
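
As the description notes, `SQLUtils` and `RRunner` pass `jvmObjectTracker = null` on the assumption that these serialization paths never touch JVM object references. A sketch of what that assumption means in practice (hypothetical snippet, not part of the commit):

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

val dos = new DataOutputStream(new ByteArrayOutputStream())

// Safe: strings, primitives, arrays, and maps are written without consulting the tracker.
SerDe.writeObject(dos, Array[Object]("a", "b"), jvmObjectTracker = null)

// Not safe: an arbitrary object falls through to writeJObj, which calls
// jvmObjectTracker.addAndGetId(value) and would throw a NullPointerException:
// SerDe.writeObject(dos, new Object, jvmObjectTracker = null)
```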