diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index ef598ae41b1f69867042e0302c24e27fb7036c15..673f9a810dd7348eca15501287d1807892db6d45 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -33,8 +33,9 @@ import com.google.common.util.concurrent.ThreadFactoryBuilder
 
 import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
 
-import spark.serializer.SerializerInstance
+import spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
 import spark.deploy.SparkHadoopUtil
+import java.nio.ByteBuffer
 
 
 /**
@@ -68,6 +69,47 @@ private object Utils extends Logging {
     return ois.readObject.asInstanceOf[T]
   }
 
+  /** Serialize via a nested stream, using the given serializer. */
+  def serializeViaNestedStream(os: OutputStream, ser: SerializerInstance)(f: SerializationStream => Unit) = {
+    val osWrapper = ser.serializeStream(new OutputStream {
+      def write(b: Int) = os.write(b)
+
+      override def write(b: Array[Byte], off: Int, len: Int) = os.write(b, off, len)
+    })
+    try {
+      f(osWrapper)
+    } finally {
+      osWrapper.close()
+    }
+  }
+
+  /** Deserialize via a nested stream, using the given serializer. */
+  def deserializeViaNestedStream(is: InputStream, ser: SerializerInstance)(f: DeserializationStream => Unit) = {
+    val isWrapper = ser.deserializeStream(new InputStream {
+      def read(): Int = is.read()
+
+      override def read(b: Array[Byte], off: Int, len: Int): Int = is.read(b, off, len)
+    })
+    try {
+      f(isWrapper)
+    } finally {
+      isWrapper.close()
+    }
+  }
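+
+  // Example usage, as in ParallelCollectionPartition.writeObject / readObject:
+  //   serializeViaNestedStream(out, ser)(_.writeObject(values))
+  //   deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())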
+
+  /**
+   * Writes the remaining bytes of a [[java.nio.ByteBuffer]] to a [[java.io.DataOutput]],
+   * copying into a temporary array only when the buffer is not array-backed.
+   */
+  def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = {
+    if (bb.hasArray) {
+      out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
+    } else {
+      val bbval = new Array[Byte](bb.remaining())
+      bb.get(bbval)
+      out.write(bbval)
+    }
+  }
+
   def isAlpha(c: Char): Boolean = {
     (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
   }
diff --git a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
index 16ba0c26f8ee10414d5843b7ecbec520765ec0c0..33079cd53937d99062c2fd6ddc0e04a68476a007 100644
--- a/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package spark.rdd
 import scala.collection.immutable.NumericRange
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.Map
-import spark.{RDD, TaskContext, SparkContext, Partition}
+import spark._
+import java.io._
+import scala.Serializable
 
 private[spark] class ParallelCollectionPartition[T: ClassManifest](
-    val rddId: Long,
-    val slice: Int,
-    values: Seq[T])
-  extends Partition with Serializable {
+    var rddId: Long,
+    var slice: Int,
+    var values: Seq[T])
+    extends Partition with Serializable {
 
   def iterator: Iterator[T] = values.iterator
 
@@ -37,15 +39,49 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
     case _ => false
   }
 
-  override val index: Int = slice
+  override def index: Int = slice
+
+  @throws(classOf[IOException])
+  private def writeObject(out: ObjectOutputStream): Unit = {
+
+    val sfactory = SparkEnv.get.serializer
+
+    // If the configured serializer is the Java serializer, just use default Java serialization
+    // rather than wrapping a nested stream, which would write a separate serialization header.
+
+    sfactory match {
+      case js: JavaSerializer => out.defaultWriteObject()
+      case _ =>
+        out.writeLong(rddId)
+        out.writeInt(slice)
+
+        val ser = sfactory.newInstance()
+        Utils.serializeViaNestedStream(out, ser)(_.writeObject(values))
+    }
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(in: ObjectInputStream): Unit = {
+
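+    // Mirror writeObject: the Java serializer uses default Java deserialization; any other
+    // serializer reads the fields and then the nested-stream-encoded values.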
+    val sfactory = SparkEnv.get.serializer
+    sfactory match {
+      case js: JavaSerializer => in.defaultReadObject()
+      case _ =>
+        rddId = in.readLong()
+        slice = in.readInt()
+
+        val ser = sfactory.newInstance()
+        Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())
+    }
+  }
 }
 
 private[spark] class ParallelCollectionRDD[T: ClassManifest](
     @transient sc: SparkContext,
     @transient data: Seq[T],
     numSlices: Int,
-    locationPrefs: Map[Int,Seq[String]])
-  extends RDD[T](sc, Nil) {
+    locationPrefs: Map[Int, Seq[String]])
+    extends RDD[T](sc, Nil) {
   // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
   // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
   // instead.
@@ -82,16 +118,17 @@ private object ParallelCollectionRDD {
           1
         }
         slice(new Range(
-            r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
+          r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
       }
       case r: Range => {
         (0 until numSlices).map(i => {
           val start = ((i * r.length.toLong) / numSlices).toInt
-          val end = (((i+1) * r.length.toLong) / numSlices).toInt
+          val end = (((i + 1) * r.length.toLong) / numSlices).toInt
           new Range(r.start + start * r.step, r.start + end * r.step, r.step)
         }).asInstanceOf[Seq[Seq[T]]]
       }
-      case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc
+      case nr: NumericRange[_] => {
+        // For ranges of Long, Double, BigInteger, etc
         val slices = new ArrayBuffer[Seq[T]](numSlices)
         val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
         var r = nr
@@ -102,10 +139,10 @@ private object ParallelCollectionRDD {
         slices
       }
       case _ => {
-        val array = seq.toArray  // To prevent O(n^2) operations for List etc
+        val array = seq.toArray // To prevent O(n^2) operations for List etc
         (0 until numSlices).map(i => {
           val start = ((i * array.length.toLong) / numSlices).toInt
-          val end = (((i+1) * array.length.toLong) / numSlices).toInt
+          val end = (((i + 1) * array.length.toLong) / numSlices).toInt
           array.slice(start, end).toSeq
         })
       }
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index dc0621ea7ba5fb20ec6be60627d20d602790a1de..89793e0e8287839f62300488c0fe7ccdbcc25885 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -21,6 +21,8 @@ import java.io._
 
 import scala.collection.mutable.Map
 import spark.executor.TaskMetrics
+import spark.{Utils, SparkEnv}
+import java.nio.ByteBuffer
 
 // Task result. Also contains updates to accumulator variables.
 // TODO: Use of distributed cache to return result is a hack to get around
@@ -30,7 +32,13 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
   def this() = this(null.asInstanceOf[T], null, null)
 
   override def writeExternal(out: ObjectOutput) {
-    out.writeObject(value)
+
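+    // Serialize the value with the configured object serializer and write it length-prefixed, so
+    // that readExternal can frame the bytes and deserialize them with the same serializer.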
+    val objectSer = SparkEnv.get.serializer.newInstance()
+    val bb = objectSer.serialize(value)
+
+    out.writeInt(bb.remaining())
+    Utils.writeByteBuffer(bb, out)
+
     out.writeInt(accumUpdates.size)
     for ((key, value) <- accumUpdates) {
       out.writeLong(key)
@@ -40,7 +48,14 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics:
   }
 
   override def readExternal(in: ObjectInput) {
-    value = in.readObject().asInstanceOf[T]
+
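+    // Read back the length-prefixed bytes written by writeExternal and deserialize them with the
+    // configured object serializer.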
+    val objectSer = SparkEnv.get.serializer.newInstance()
+
+    val blen = in.readInt()
+    val byteVal = new Array[Byte](blen)
+    in.readFully(byteVal)
+    value = objectSer.deserialize(ByteBuffer.wrap(byteVal))
+
     val numUpdates = in.readInt
     if (numUpdates == 0) {
       accumUpdates = null
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
index d2110bd098e59b3963ff0819235f15c1ab5d0d3e..7f855cd345b7f6ab4ea600d0c9aed39f86574f87 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -92,7 +92,8 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
   val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
 
   // Serializer for closures and tasks.
-  val ser = SparkEnv.get.closureSerializer.newInstance()
+  val env = SparkEnv.get
+  val ser = env.closureSerializer.newInstance()
 
   val tasks = taskSet.tasks
   val numTasks = tasks.length
@@ -534,6 +535,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
   }
 
   override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
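+    // statusUpdate may run on a backend thread that has no SparkEnv set; install the env captured
+    // at construction so result deserialization (TaskResult.readExternal now calls
+    // SparkEnv.get.serializer) can find it.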
+    SparkEnv.set(env)
     state match {
       case TaskState.FINISHED =>
         taskFinished(tid, state, serializedData)
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index bb0c836e86b2a39c6c0f39082806b77ce3265971..f274b1a767984524df7b0fba6d92f9045c6ce312 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -169,7 +169,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
     // Set the Spark execution environment for the worker thread
     SparkEnv.set(env)
     val ser = SparkEnv.get.closureSerializer.newInstance()
-    var attemptedTask: Option[Task[_]] = None
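+    // Round-trip the result through the object serializer (spark.serializer) rather than the
+    // closure serializer, mirroring how TaskResult now serializes results on executors.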
+    val objectSer = SparkEnv.get.serializer.newInstance()
+    var attemptedTask: Option[Task[_]] = None
     val start = System.currentTimeMillis()
     var taskStart: Long = 0
     try {
@@ -193,9 +194,9 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
       // executor does. This is useful to catch serialization errors early
       // on in development (so when users move their local Spark programs
       // to the cluster, they don't get surprised by serialization errors).
-      val serResult = ser.serialize(result)
+      val serResult = objectSer.serialize(result)
       deserializedTask.metrics.get.resultSize = serResult.limit()
-      val resultToReturn = ser.deserialize[Any](serResult)
+      val resultToReturn = objectSer.deserialize[Any](serResult)
       val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
         ser.serialize(Accumulators.values))
       val serviceTime = System.currentTimeMillis() - taskStart
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
index 4ab15532cf8ac03ed2c5dc386fc5ad39498e0181..c38eeb9e11eb96b23fad0efd906cfc927ce179af 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -42,7 +42,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   val taskInfos = new HashMap[Long, TaskInfo]
   val numTasks = taskSet.tasks.size
   var numFinished = 0
-  val ser = SparkEnv.get.closureSerializer.newInstance()
+  val env = SparkEnv.get
+  val ser = env.closureSerializer.newInstance()
   val copiesRunning = new Array[Int](numTasks)
   val finished = new Array[Boolean](numTasks)
   val numFailures = new Array[Int](numTasks)
@@ -143,6 +144,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   }
 
   override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
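+    // As in ClusterTaskSetManager, propagate the SparkEnv captured at construction to the thread
+    // handling status updates before any task results are deserialized.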
+    SparkEnv.set(env)
     state match {
       case TaskState.FINISHED =>
         taskEnded(tid, state, serializedData)
diff --git a/core/src/test/scala/spark/KryoSerializerSuite.scala b/core/src/test/scala/spark/KryoSerializerSuite.scala
index 30d2d5282bc84bf887d25c5233020be9e7d06499..01390027c8de1bb5cfd5f1d8f8cc3f4a4c2a7d55 100644
--- a/core/src/test/scala/spark/KryoSerializerSuite.scala
+++ b/core/src/test/scala/spark/KryoSerializerSuite.scala
@@ -22,7 +22,9 @@ import scala.collection.mutable
 import org.scalatest.FunSuite
 import com.esotericsoftware.kryo._
 
-class KryoSerializerSuite extends FunSuite {
+import KryoTest._
+
+class KryoSerializerSuite extends FunSuite with SharedSparkContext {
   test("basic types") {
     val ser = (new KryoSerializer).newInstance()
     def check[T](t: T) {
@@ -124,6 +126,45 @@ class KryoSerializerSuite extends FunSuite {
 
     System.clearProperty("spark.kryo.registrator")
   }
+
+  test("kryo with collect") {
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_)).collect().map(_.x)
+    assert(control === result.toSeq)
+  }
+
+  test("kryo with parallelize") {
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control.map(new ClassWithoutNoArgConstructor(_))).map(_.x).collect()
+    assert(control === result.toSeq)
+  }
+
+  test("kryo with reduce") {
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+        .reduce((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+    assert(control.sum === result)
+  }
+
+  // TODO: this still doesn't work
+  ignore("kryo with fold") {
+    val control = 1 :: 2 :: Nil
+    val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
+        .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
+    assert(10 + control.sum === result)
+  }
+
+  override def beforeAll() {
+    System.setProperty("spark.serializer", "spark.KryoSerializer")
+    System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName)
+    super.beforeAll()
+  }
+
+  override def afterAll() {
+    super.afterAll()
+    System.clearProperty("spark.kryo.registrator")
+    System.clearProperty("spark.serializer")
+  }
 }
 
 object KryoTest {