diff --git a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
index 28a8accb333e5c10c45d0f9e8ad186b2afd31dce..9cfd41407f97b78bac1fbe268e62e30327975f8a 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableHyperLogLog.scala
@@ -17,26 +17,27 @@
 
 package org.apache.spark.util
 
-import java.io.{ObjectOutputStream, ObjectInputStream}
+import java.io.{Externalizable, ObjectOutput, ObjectInput}
 
 import com.clearspring.analytics.stream.cardinality.{ICardinality, HyperLogLog}
 
 /**
- * A wrapper around com.clearspring.analytics.stream.cardinality.HyperLogLog that is serializable.
+ * A wrapper around [[com.clearspring.analytics.stream.cardinality.HyperLogLog]] that is serializable.
  */
 private[spark]
-class SerializableHyperLogLog(@transient var value: ICardinality) extends Serializable {
+class SerializableHyperLogLog(var value: ICardinality) extends Externalizable {
+  def this() = this(null)  // For deserialization
 
   def merge(other: SerializableHyperLogLog) = new SerializableHyperLogLog(value.merge(other.value))
 
-  private def readObject(in: ObjectInputStream) {
+  def readExternal(in: ObjectInput) {
     val byteLength = in.readInt()
     val bytes = new Array[Byte](byteLength)
     in.readFully(bytes)
     value = HyperLogLog.Builder.build(bytes)
   }
 
-  private def writeObject(out: ObjectOutputStream) {
+  def writeExternal(out: ObjectOutput) {
     val bytes = value.getBytes()
     out.writeInt(bytes.length)
     out.write(bytes)
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index c016c5117149fabcf03e71f6fdd755c7bb60b3cd..18529710fecded15608379d48ca0a4d948db5a7d 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -172,6 +172,10 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
     assert (sc.parallelize( Array((1, 11), (2, 22), (3, 33)) ).collect().head === (1, 11))
   }
 
+  test("kryo with SerializableHyperLogLog") {
+    assert(sc.parallelize( Array(1, 2, 3, 2, 3, 3, 2, 3, 1) ).countDistinct(0.01) === 3)
+  }
+
   test("kryo with reduce") {
     val control = 1 :: 2 :: Nil
     val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
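
Why this change works (illustration, not part of the patch): the private readObject/writeObject methods are hooks that only Java's built-in serialization calls, so the custom byte-level encoding was invisible to other serializers, and the @transient annotation meant a field-based serializer would drop the wrapped ICardinality altogether. Externalizable makes the encoding an explicit part of the class's contract through the public readExternal/writeExternal methods; the price is the no-arg constructor, which the deserializer uses to instantiate the object before calling readExternal. The sketch below round-trips the wrapper through plain Java serialization to show that mechanism; the object name HllRoundTrip is hypothetical, and it is placed in org.apache.spark.util only because the wrapper is private[spark]. It assumes clearspring's stream-lib (new HyperLogLog(rsd), offer, cardinality) is on the classpath.

// Illustration only, not part of the patch. Placed in this package
// because SerializableHyperLogLog is private[spark].
package org.apache.spark.util

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import com.clearspring.analytics.stream.cardinality.HyperLogLog

object HllRoundTrip {  // hypothetical demo object
  def main(args: Array[String]) {
    val hll = new HyperLogLog(0.01)  // 1% relative standard deviation
    Seq(1, 2, 3, 2, 3, 3, 2, 3, 1).foreach(hll.offer(_))
    val wrapped = new SerializableHyperLogLog(hll)

    // Serialize: the stream sees an Externalizable and calls writeExternal.
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(wrapped)
    out.close()

    // Deserialize: the runtime invokes the no-arg constructor added in
    // this patch, then readExternal rebuilds the HyperLogLog from bytes.
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val restored = in.readObject().asInstanceOf[SerializableHyperLogLog]
    in.close()

    println(restored.value.cardinality())  // expected: 3 distinct values
  }
}

The new KryoSerializerSuite test exercises the same round trip under Kryo: countDistinct builds a HyperLogLog per partition and merges them, so those instances must survive serialization when partial results are combined.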