From f5486e9f75d62919583da5ecf9a9ad00222b2227 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@apache.org> Date: Sun, 16 Mar 2014 09:57:21 -0700 Subject: [PATCH] SPARK-1255: Allow user to pass Serializer object instead of class name for shuffle. This is more general than simply passing a string name and leaves more room for performance optimizations. Note that this is technically an API breaking change in the following two ways: 1. The shuffle serializer specification in ShuffleDependency now requires an object instead of a String (the class name), but I suspect nobody else in this world has used this API other than me in GraphX and Shark. 2. Serializers in Spark from now on are required to be serializable. Author: Reynold Xin <rxin@apache.org> Closes #149 from rxin/serializer and squashes the following commits: 5acaccd [Reynold Xin] Properly call serializer's constructors. 2a8d75a [Reynold Xin] Added more documentation for the serializer option in ShuffleDependency. 7420185 [Reynold Xin] Allow user to pass Serializer object instead of class name for shuffle. --- .../scala/org/apache/spark/Dependency.scala | 6 +- .../org/apache/spark/ShuffleFetcher.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 24 +++--- .../org/apache/spark/rdd/CoGroupedRDD.scala | 18 ++--- .../apache/spark/rdd/PairRDDFunctions.scala | 7 +- .../org/apache/spark/rdd/ShuffledRDD.scala | 13 ++-- .../org/apache/spark/rdd/SubtractedRDD.scala | 20 ++--- .../spark/scheduler/ShuffleMapTask.scala | 3 +- .../spark/serializer/JavaSerializer.scala | 27 ++++--- .../spark/serializer/KryoSerializer.scala | 16 ++-- .../apache/spark/serializer/Serializer.scala | 16 +++- .../spark/serializer/SerializerManager.scala | 75 ------------------- .../collection/ExternalAppendOnlyMap.scala | 2 +- .../scala/org/apache/spark/ShuffleSuite.scala | 9 ++- .../apache/spark/graphx/impl/GraphImpl.scala | 2 +- .../graphx/impl/MessageToPartition.scala | 12 +-- .../spark/graphx/impl/Serializers.scala | 14 ++-- .../apache/spark/graphx/SerializerSuite.scala | 30 +++----- 18 files changed, 125 insertions(+), 171 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index cc30105940..448f87b81e 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer /** * Base class for dependencies. @@ -43,12 +44,13 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializerClass class name of the serializer to use + * @param serializer [[Serializer]] to use. If set to null, the default serializer, as specified + * by `spark.serializer` config option, will be used.
*/ class ShuffleDependency[K, V]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializerClass: String = null) + val serializer: Serializer = null) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala index e8f756c408..a4f69b6b22 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala @@ -29,7 +29,7 @@ private[spark] abstract class ShuffleFetcher { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] + serializer: Serializer = SparkEnv.get.serializer): Iterator[T] /** Stop the fetcher */ def stop() {} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5e43b51984..d035d909b7 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -28,7 +28,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.storage.{BlockManager, BlockManagerMaster, BlockManagerMasterActor} import org.apache.spark.network.ConnectionManager -import org.apache.spark.serializer.{Serializer, SerializerManager} +import org.apache.spark.serializer.Serializer import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -41,7 +41,6 @@ import org.apache.spark.util.{AkkaUtils, Utils} class SparkEnv private[spark] ( val executorId: String, val actorSystem: ActorSystem, - val serializerManager: SerializerManager, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -139,16 +138,22 @@ object SparkEnv extends Logging { // defaultClassName if the property is not set, and return it as a T def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) - Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] + val cls = Class.forName(name, true, classLoader) + // First try with the constructor that takes SparkConf. If we can't find one, + // use a no-arg constructor instead. 
+ try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } - val serializerManager = new SerializerManager - val serializer = serializerManager.setDefault( - conf.get("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf) + val serializer = instantiateClass[Serializer]( + "spark.serializer", "org.apache.spark.serializer.JavaSerializer") - val closureSerializer = serializerManager.get( - conf.get("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"), - conf) + val closureSerializer = instantiateClass[Serializer]( + "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") def registerOrLookup(name: String, newActor: => Actor): ActorRef = { if (isDriver) { @@ -220,7 +225,6 @@ object SparkEnv extends Logging { new SparkEnv( executorId, actorSystem, - serializerManager, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 699a10c96c..8561711931 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} +import org.apache.spark.serializer.Serializer private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -66,10 +67,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): CoGroupedRDD[K] = { - serializerClass = cls + def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { + this.serializer = serializer this } @@ -80,7 +81,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializerClass) + new ShuffleDependency[Any, Any](rdd, part, serializer) } } } @@ -113,18 +114,17 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => // Read them from the parent val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - } - case ShuffleCoGroupSplitDep(shuffleId) => { + + case ShuffleCoGroupSplitDep(shuffleId) => // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - val ser = SparkEnv.get.serializerManager.get(serializerClass, sparkConf) + val ser = Serializer.getSerializer(serializer) val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) rddIterators += ((it, depNum)) - } } if (!externalSorting) { diff --git 
a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index b20ed99f89..b0d322fe27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -44,6 +44,7 @@ import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.serializer.Serializer import org.apache.spark.util.SerializableHyperLogLog /** @@ -73,7 +74,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) mergeCombiners: (C, C) => C, partitioner: Partitioner, mapSideCombine: Boolean = true, - serializerClass: String = null): RDD[(K, C)] = { + serializer: Serializer = null): RDD[(K, C)] = { require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0 if (getKeyClass().isArray) { if (mapSideCombine) { @@ -93,13 +94,13 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) aggregator.combineValuesByKey(iter, context) }, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) - .setSerializer(serializerClass) + .setSerializer(serializer) partitioned.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context)) }, preservesPartitioning = true) } else { // Don't apply map-side combiner. - val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) + val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer) values.mapPartitionsWithContext((context, iter) => { new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 0bbda25a90..02660ea6a4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -38,15 +39,15 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): ShuffledRDD[K, V, P] = { - serializerClass = cls + def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { + this.serializer = serializer this } override def getDependencies: Seq[Dependency[_]] = { - List(new ShuffleDependency(prev, part, serializerClass)) + List(new ShuffleDependency(prev, part, serializer)) } override val partitioner = Some(part) @@ -57,8 +58,8 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, - SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)) + val ser = 
Serializer.getSerializer(serializer) + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 5fe9f363db..9a09c05bbc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -30,6 +30,7 @@ import org.apache.spark.Partitioner import org.apache.spark.ShuffleDependency import org.apache.spark.SparkEnv import org.apache.spark.TaskContext +import org.apache.spark.serializer.Serializer /** * An optimized version of cogroup for set difference/subtraction. @@ -53,10 +54,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializerClass: String = null + private var serializer: Serializer = null - def setSerializer(cls: String): SubtractedRDD[K, V, W] = { - serializerClass = cls + def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { + this.serializer = serializer this } @@ -67,7 +68,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency(rdd, part, serializerClass) + new ShuffleDependency(rdd, part, serializer) } } } @@ -92,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) + val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -105,14 +106,13 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } } def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit) = dep match { - case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - } - case ShuffleCoGroupSplitDep(shuffleId) => { + + case ShuffleCoGroupSplitDep(shuffleId) => val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, serializer) + context, ser) iter.foreach(op) - } } // the first dep is rdd1; add all values to the map integrate(partition.deps(0), t => getSeq(t._1) += t._2) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 77789031f4..2a9edf4a76 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -26,6 +26,7 @@ import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} @@ -153,7 +154,7 @@ private[spark] class ShuffleMapTask( try { // Obtain all the block writers for shuffle blocks. 
- val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf) + val ser = Serializer.getSerializer(dep.serializer) shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index bfa647f7f0..18a68b05fa 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -23,11 +23,10 @@ import java.nio.ByteBuffer import org.apache.spark.SparkConf import org.apache.spark.util.ByteBufferInputStream -private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf) +private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) extends SerializationStream { - val objOut = new ObjectOutputStream(out) - var counter = 0 - val counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) + private val objOut = new ObjectOutputStream(out) + private var counter = 0 /** * Calling reset to avoid memory leak: @@ -51,7 +50,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, conf: SparkConf) private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader) extends DeserializationStream { - val objIn = new ObjectInputStream(in) { + private val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass) = Class.forName(desc.getName, false, loader) } @@ -60,7 +59,7 @@ extends DeserializationStream { def close() { objIn.close() } } -private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerInstance { +private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance { def serialize[T](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() val out = serializeStream(bos) @@ -82,7 +81,7 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI } def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s, conf) + new JavaSerializationStream(s, counterReset) } def deserializeStream(s: InputStream): DeserializationStream = { @@ -97,6 +96,16 @@ private[spark] class JavaSerializerInstance(conf: SparkConf) extends SerializerI /** * A Spark serializer that uses Java's built-in serialization. */ -class JavaSerializer(conf: SparkConf) extends Serializer { - def newInstance(): SerializerInstance = new JavaSerializerInstance(conf) +class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { + private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) + + def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset) + + override def writeExternal(out: ObjectOutput) { + out.writeInt(counterReset) + } + + override def readExternal(in: ObjectInput) { + counterReset = in.readInt() + } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 920490f9d0..6b6d814c1f 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -34,10 +34,14 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
*/ -class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { - private val bufferSize = { - conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 - } +class KryoSerializer(conf: SparkConf) + extends org.apache.spark.serializer.Serializer + with Logging + with Serializable { + + private val bufferSize = conf.getInt("spark.kryoserializer.buffer.mb", 2) * 1024 * 1024 + private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) + private val registrator = conf.getOption("spark.kryo.registrator") def newKryoOutput() = new KryoOutput(bufferSize) @@ -48,7 +52,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. - kryo.setReferences(conf.getBoolean("spark.kryo.referenceTracking", true)) + kryo.setReferences(referenceTracking) for (cls <- KryoSerializer.toRegister) kryo.register(cls) @@ -58,7 +62,7 @@ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serial // Allow the user to register their own classes by setting spark.kryo.registrator try { - for (regCls <- conf.getOption("spark.kryo.registrator")) { + for (regCls <- registrator) { logDebug("Running user registrator: " + regCls) val reg = Class.forName(regCls, true, classLoader).newInstance() .asInstanceOf[KryoRegistrator] diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 16677ab54b..099143494b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -23,21 +23,31 @@ import java.nio.ByteBuffer import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.SparkEnv /** * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual * serialization and are guaranteed to only be called from one thread at a time. * - * Implementations of this trait should have a zero-arg constructor or a constructor that accepts a - * [[org.apache.spark.SparkConf]] as parameter. If both constructors are defined, the latter takes - * precedence. + * Implementations of this trait should implement: + * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]] + * as parameter. If both constructors are defined, the latter takes precedence. + * + * 2. Java serialization interface. */ trait Serializer { def newInstance(): SerializerInstance } +object Serializer { + def getSerializer(serializer: Serializer): Serializer = { + if (serializer == null) SparkEnv.get.serializer else serializer + } +} + + /** * An instance of a serializer, for use by one thread at a time. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala deleted file mode 100644 index 65ac0155f4..0000000000 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.serializer - -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.SparkConf - -/** - * A service that returns a serializer object given the serializer's class name. If a previous - * instance of the serializer object has been created, the get method returns that instead of - * creating a new one. - */ -private[spark] class SerializerManager { - // TODO: Consider moving this into SparkConf itself to remove the global singleton. - - private val serializers = new ConcurrentHashMap[String, Serializer] - private var _default: Serializer = _ - - def default = _default - - def setDefault(clsName: String, conf: SparkConf): Serializer = { - _default = get(clsName, conf) - _default - } - - def get(clsName: String, conf: SparkConf): Serializer = { - if (clsName == null) { - default - } else { - var serializer = serializers.get(clsName) - if (serializer != null) { - // If the serializer has been created previously, reuse that. - serializer - } else this.synchronized { - // Otherwise, create a new one. But make sure no other thread has attempted - // to create another new one at the same time. - serializer = serializers.get(clsName) - if (serializer == null) { - val clsLoader = Thread.currentThread.getContextClassLoader - val cls = Class.forName(clsName, true, clsLoader) - - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. 
- try { - val constructor = cls.getConstructor(classOf[SparkConf]) - serializer = constructor.newInstance(conf).asInstanceOf[Serializer] - } catch { - case _: NoSuchMethodException => - val constructor = cls.getConstructor() - serializer = constructor.newInstance().asInstanceOf[Serializer] - } - - serializers.put(clsName, serializer) - } - serializer - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index ed74a31f05..caa06d5b44 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -60,7 +60,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - serializer: Serializer = SparkEnv.get.serializerManager.default, + serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) extends Iterable[(K, C)] with Serializable with Logging { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index abea36f7c8..be6508a40e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -27,6 +27,9 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + val conf = new SparkConf(loadDefaults = false) + test("groupByKey without compression") { try { System.setProperty("spark.shuffle.compress", "false") @@ -54,7 +57,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(NUM_BLOCKS)).setSerializer(classOf[KryoSerializer].getName) + b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 10) @@ -76,7 +79,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // If the Kryo serializer is not used correctly, the shuffle would fail because the // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( - b, new HashPartitioner(3)).setSerializer(classOf[KryoSerializer].getName) + b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf)) assert(c.count === 10) } @@ -92,7 +95,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { // NOTE: The default Java serializer doesn't create zero-sized blocks. 
// So, use Kryo val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - .setSerializer(classOf[KryoSerializer].getName) + .setSerializer(new KryoSerializer(conf)) val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId assert(c.count === 4) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 1d029bf009..5e9be18990 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -391,6 +391,6 @@ object GraphImpl { // TODO: Consider doing map side distinct before shuffle. new ShuffledRDD[VertexId, Int, (VertexId, Int)]( edges.collectVertexIds.map(vid => (vid, 0)), partitioner) - .setSerializer(classOf[VertexIdMsgSerializer].getName) + .setSerializer(new VertexIdMsgSerializer) } } // end of object GraphImpl diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala index e9ee09c361..fe6fe76def 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala @@ -65,11 +65,11 @@ class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new IntVertexBroadcastMsgSerializer) } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new LongVertexBroadcastMsgSerializer) } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName) + rdd.setSerializer(new DoubleVertexBroadcastMsgSerializer) } rdd } @@ -104,11 +104,11 @@ object MsgRDDFunctions { // Set a custom serializer if the data is of int or double type. if (classTag[T] == ClassTag.Int) { - rdd.setSerializer(classOf[IntAggMsgSerializer].getName) + rdd.setSerializer(new IntAggMsgSerializer) } else if (classTag[T] == ClassTag.Long) { - rdd.setSerializer(classOf[LongAggMsgSerializer].getName) + rdd.setSerializer(new LongAggMsgSerializer) } else if (classTag[T] == ClassTag.Double) { - rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName) + rdd.setSerializer(new DoubleAggMsgSerializer) } rdd } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala index c74d487e20..34a145e018 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala @@ -25,7 +25,7 @@ import org.apache.spark.graphx._ import org.apache.spark.serializer._ private[graphx] -class VertexIdMsgSerializer(conf: SparkConf) extends Serializer { +class VertexIdMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -46,7 +46,7 @@ class VertexIdMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Int]. 
*/ private[graphx] -class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class IntVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -70,7 +70,7 @@ class IntVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Long]. */ private[graphx] -class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class LongVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -94,7 +94,7 @@ class LongVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for VertexBroadcastMessage[Double]. */ private[graphx] -class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { +class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -118,7 +118,7 @@ class DoubleVertexBroadcastMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Int]. */ private[graphx] -class IntAggMsgSerializer(conf: SparkConf) extends Serializer { +class IntAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -142,7 +142,7 @@ class IntAggMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Long]. */ private[graphx] -class LongAggMsgSerializer(conf: SparkConf) extends Serializer { +class LongAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { @@ -166,7 +166,7 @@ class LongAggMsgSerializer(conf: SparkConf) extends Serializer { /** A special shuffle serializer for AggregationMessage[Double]. 
*/ private[graphx] -class DoubleAggMsgSerializer(conf: SparkConf) extends Serializer { +class DoubleAggMsgSerializer extends Serializer with Serializable { override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala index e5a582b47b..73438d9535 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala @@ -32,15 +32,14 @@ import org.apache.spark.serializer.SerializationStream class SerializerSuite extends FunSuite with LocalSparkContext { test("IntVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Int](3, 4, 5) val bout = new ByteArrayOutputStream - val outStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -54,15 +53,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("LongVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Long](3, 4, 5) val bout = new ByteArrayOutputStream - val outStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -76,15 +74,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("DoubleVertexBroadcastMsgSerializer") { - val conf = new SparkConf(false) val outMsg = new VertexBroadcastMsg[Double](3, 4, 5.0) val bout = new ByteArrayOutputStream - val outStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleVertexBroadcastMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject() val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject() assert(outMsg.vid === inMsg1.vid) @@ -98,15 +95,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("IntAggMsgSerializer") { - val conf = new 
SparkConf(false) val outMsg = (4: VertexId, 5) val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Int) = inStrm.readObject() val inMsg2: (VertexId, Int) = inStrm.readObject() assert(outMsg === inMsg1) @@ -118,15 +114,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("LongAggMsgSerializer") { - val conf = new SparkConf(false) val outMsg = (4: VertexId, 1L << 32) val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Long) = inStrm.readObject() val inMsg2: (VertexId, Long) = inStrm.readObject() assert(outMsg === inMsg1) @@ -138,15 +133,14 @@ class SerializerSuite extends FunSuite with LocalSparkContext { } test("DoubleAggMsgSerializer") { - val conf = new SparkConf(false) val outMsg = (4: VertexId, 5.0) val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer(conf).newInstance().serializeStream(bout) + val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) outStrm.writeObject(outMsg) outStrm.writeObject(outMsg) bout.flush() val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer(conf).newInstance().deserializeStream(bin) + val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) val inMsg1: (VertexId, Double) = inStrm.readObject() val inMsg2: (VertexId, Double) = inStrm.readObject() assert(outMsg === inMsg1) -- GitLab
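
For illustration, a minimal usage sketch of the new shuffle-serializer API described in the commit message. The setSerializer call with a KryoSerializer instance mirrors the ShuffleSuite changes above; the driver program, object name, application name, and partition count are hypothetical and not part of the patch.

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer

// Hypothetical driver program: after this change, a Serializer *object* is
// passed to the shuffle instead of a class-name String.
object ShuffleSerializerExample {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("shuffle-serializer-example")
    val sc = new SparkContext(conf)

    val pairs = sc.parallelize(1 to 10).map(i => (i, i * i))

    // Before this patch: .setSerializer(classOf[KryoSerializer].getName)
    // After this patch: pass an instance. Serializers must now be serializable,
    // since the instance travels with the ShuffleDependency.
    val shuffled = new ShuffledRDD[Int, Int, (Int, Int)](pairs, new HashPartitioner(4))
      .setSerializer(new KryoSerializer(conf))

    // Leaving the serializer unset (null) falls back to the default configured
    // via spark.serializer; Serializer.getSerializer resolves that at fetch time.
    println(shuffled.count())

    sc.stop()
  }
}
```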