From baa3e633e18c47b12e79fe3ddc01fc8ec010f096 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <simonh@tw.ibm.com>
Date: Mon, 13 Jun 2016 19:59:53 -0700
Subject: [PATCH] [SPARK-15364][ML][PYSPARK] Implement PySpark picklers for
 ml.Vector and ml.Matrix under spark.ml.python

## What changes were proposed in this pull request?

Now we have PySpark picklers for new and old vector/matrix, individually. However, they are all implemented under `PythonMLlibAPI`. To separate spark.mllib from spark.ml, we should implement the picklers of new vector/matrix under `spark.ml.python` instead.

## How was this patch tested?
Existing tests.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>

Closes #13219 from viirya/pyspark-pickler-ml.
---
 .../org/apache/spark/ml/python/MLSerDe.scala  | 224 +++++++++++++
 .../mllib/api/python/PythonMLLibAPI.scala     | 309 ++++--------------
 .../apache/spark/ml/python/MLSerDeSuite.scala |  72 ++++
 python/pyspark/java_gateway.py                |   1 +
 python/pyspark/ml/base.py                     |   2 +-
 python/pyspark/ml/classification.py           |   2 +-
 python/pyspark/ml/clustering.py               |   2 +-
 python/pyspark/ml/common.py                   | 137 ++++++++
 python/pyspark/ml/evaluation.py               |   2 +-
 python/pyspark/ml/feature.py                  |   2 +-
 python/pyspark/ml/pipeline.py                 |   2 +-
 python/pyspark/ml/recommendation.py           |   2 +-
 python/pyspark/ml/regression.py               |   2 +-
 python/pyspark/ml/tests.py                    |  10 +-
 python/pyspark/ml/tuning.py                   |   2 +-
 python/pyspark/ml/util.py                     |   2 +-
 python/pyspark/ml/wrapper.py                  |   2 +-
 17 files changed, 518 insertions(+), 257 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
 create mode 100644 python/pyspark/ml/common.py

diff --git a/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
new file mode 100644
index 0000000000..1279c901c5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
+import java.util.{ArrayList => JArrayList}
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle._
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
+import org.apache.spark.ml.linalg._
+import org.apache.spark.mllib.api.python.SerDeBase
+import org.apache.spark.rdd.RDD
+
+/**
+ * SerDe utility functions for pyspark.ml.
+ */
+private[spark] object MLSerDe extends SerDeBase with Serializable {
+
+  override val PYSPARK_PACKAGE = "pyspark.ml"
+
+  // Pickler for DenseVector
+  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val vector: DenseVector = obj.asInstanceOf[DenseVector]
+      val bytes = new Array[Byte](8 * vector.size)
+      val bb = ByteBuffer.wrap(bytes)
+      bb.order(ByteOrder.nativeOrder())
+      val db = bb.asDoubleBuffer()
+      db.put(vector.values)
+
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(bytes.length))
+      out.write(bytes)
+      out.write(Opcodes.TUPLE1)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      require(args.length == 1)
+      if (args.length != 1) {
+        throw new PickleException("should be 1")
+      }
+      val bytes = getBytes(args(0))
+      val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+      bb.order(ByteOrder.nativeOrder())
+      val db = bb.asDoubleBuffer()
+      val ans = new Array[Double](bytes.length / 8)
+      db.get(ans)
+      Vectors.dense(ans)
+    }
+  }
+
+  // Pickler for DenseMatrix
+  private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+      val bytes = new Array[Byte](8 * m.values.length)
+      val order = ByteOrder.nativeOrder()
+      val isTransposed = if (m.isTransposed) 1 else 0
+      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+      out.write(Opcodes.MARK)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(m.numRows))
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(m.numCols))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(bytes.length))
+      out.write(bytes)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 4) {
+        throw new PickleException("should be 4")
+      }
+      val bytes = getBytes(args(2))
+      val n = bytes.length / 8
+      val values = new Array[Double](n)
+      val order = ByteOrder.nativeOrder()
+      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+      val isTransposed = args(3).asInstanceOf[Int] == 1
+      new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
+    }
+  }
+
+  // Pickler for SparseMatrix
+  private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val s = obj.asInstanceOf[SparseMatrix]
+      val order = ByteOrder.nativeOrder()
+
+      val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
+      val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
+      val valuesBytes = new Array[Byte](8 * s.values.length)
+      val isTransposed = if (s.isTransposed) 1 else 0
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
+
+      out.write(Opcodes.MARK)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numRows))
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(s.numCols))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
+      out.write(colPtrsBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
+      out.write(indicesBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
+      out.write(valuesBytes)
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(isTransposed))
+      out.write(Opcodes.TUPLE)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 6) {
+        throw new PickleException("should be 6")
+      }
+      val order = ByteOrder.nativeOrder()
+      val colPtrsBytes = getBytes(args(2))
+      val indicesBytes = getBytes(args(3))
+      val valuesBytes = getBytes(args(4))
+      val colPtrs = new Array[Int](colPtrsBytes.length / 4)
+      val rowIndices = new Array[Int](indicesBytes.length / 4)
+      val values = new Array[Double](valuesBytes.length / 8)
+      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
+      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
+      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
+      val isTransposed = args(5).asInstanceOf[Int] == 1
+      new SparseMatrix(
+        args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
+        isTransposed)
+    }
+  }
+
+  // Pickler for SparseVector
+  private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
+
+    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+      val v: SparseVector = obj.asInstanceOf[SparseVector]
+      val n = v.indices.length
+      val indiceBytes = new Array[Byte](4 * n)
+      val order = ByteOrder.nativeOrder()
+      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+      val valueBytes = new Array[Byte](8 * n)
+      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+      out.write(Opcodes.BININT)
+      out.write(PickleUtils.integer_to_bytes(v.size))
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+      out.write(indiceBytes)
+      out.write(Opcodes.BINSTRING)
+      out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+      out.write(valueBytes)
+      out.write(Opcodes.TUPLE3)
+    }
+
+    def construct(args: Array[Object]): Object = {
+      if (args.length != 3) {
+        throw new PickleException("should be 3")
+      }
+      val size = args(0).asInstanceOf[Int]
+      val indiceBytes = getBytes(args(1))
+      val valueBytes = getBytes(args(2))
+      val n = indiceBytes.length / 4
+      val indices = new Array[Int](n)
+      val values = new Array[Double](n)
+      if (n > 0) {
+        val order = ByteOrder.nativeOrder()
+        ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+        ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+      }
+      new SparseVector(size, indices, values)
+    }
+  }
+
+  var initialized = false
+  // This should be called before trying to serialize any above classes
+  // In cluster mode, this should be put in the closure
+  override def initialize(): Unit = {
+    SerDeUtil.initialize()
+    synchronized {
+      if (!initialized) {
+        new DenseVectorPickler().register()
+        new DenseMatrixPickler().register()
+        new SparseMatrixPickler().register()
+        new SparseVectorPickler().register()
+        initialized = true
+      }
+    }
+  }
+  // will not called in Executor automatically
+  initialize()
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e43469bf1c..7df61601fb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -30,7 +30,6 @@ import net.razorvine.pickle._
 
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.ml.linalg.{DenseMatrix => NewDenseMatrix, DenseVector => NewDenseVector, SparseMatrix => NewSparseMatrix, SparseVector => NewSparseVector, Vectors => NewVectors}
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
 import org.apache.spark.mllib.evaluation.RankingMetrics
@@ -1205,23 +1204,21 @@ private[python] class PythonMLLibAPI extends Serializable {
 }
 
 /**
- * SerDe utility functions for PythonMLLibAPI.
+ * Basic SerDe utility class.
  */
-private[spark] object SerDe extends Serializable {
+private[spark] abstract class SerDeBase {
 
-  val PYSPARK_PACKAGE = "pyspark.mllib"
-  val PYSPARK_ML_PACKAGE = "pyspark.ml"
+  val PYSPARK_PACKAGE: String
+  def initialize(): Unit
 
   /**
    * Base class used for pickle
    */
-  private[python] abstract class BasePickler[T: ClassTag]
+  private[spark] abstract class BasePickler[T: ClassTag]
     extends IObjectPickler with IObjectConstructor {
 
-    protected def packageName: String = PYSPARK_PACKAGE
-
     private val cls = implicitly[ClassTag[T]].runtimeClass
-    private val module = packageName + "." + cls.getName.split('.')(4)
+    private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
     private val name = cls.getSimpleName
 
     // register this to Pickler and Unpickler
@@ -1268,45 +1265,73 @@ private[spark] object SerDe extends Serializable {
     private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
   }
 
-  // Pickler for (mllib) DenseVector
-  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
+  def dumps(obj: AnyRef): Array[Byte] = {
+    obj match {
+      // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
+      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
+      case _ => new Pickler().dumps(obj)
+    }
+  }
 
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val vector: DenseVector = obj.asInstanceOf[DenseVector]
-      val bytes = new Array[Byte](8 * vector.size)
-      val bb = ByteBuffer.wrap(bytes)
-      bb.order(ByteOrder.nativeOrder())
-      val db = bb.asDoubleBuffer()
-      db.put(vector.values)
+  def loads(bytes: Array[Byte]): AnyRef = {
+    new Unpickler().loads(bytes)
+  }
 
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(bytes.length))
-      out.write(bytes)
-      out.write(Opcodes.TUPLE1)
+  /* convert object into Tuple */
+  def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
+    rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
+  }
+
+  /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
+  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+    rdd.map(x => Array(x._1, x._2))
+  }
+
+  /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+   * PySpark.
+   */
+  def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
+    jRDD.rdd.mapPartitions { iter =>
+      initialize()  // let it called in executor
+      new SerDeUtil.AutoBatchedPickler(iter)
     }
+  }
 
-    def construct(args: Array[Object]): Object = {
-      require(args.length == 1)
-      if (args.length != 1) {
-        throw new PickleException("should be 1")
+  /**
+   * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+   */
+  def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+    pyRDD.rdd.mapPartitions { iter =>
+      initialize()  // let it called in executor
+      val unpickle = new Unpickler
+      iter.flatMap { row =>
+        val obj = unpickle.loads(row)
+        if (batched) {
+          obj match {
+            case list: JArrayList[_] => list.asScala
+            case arr: Array[_] => arr
+          }
+        } else {
+          Seq(obj)
+        }
       }
-      val bytes = getBytes(args(0))
-      val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
-      bb.order(ByteOrder.nativeOrder())
-      val db = bb.asDoubleBuffer()
-      val ans = new Array[Double](bytes.length / 8)
-      db.get(ans)
-      Vectors.dense(ans)
-    }
+    }.toJavaRDD()
   }
+}
 
-  // Pickler for (new) DenseVector
-  private[python] class NewDenseVectorPickler extends BasePickler[NewDenseVector] {
+/**
+ * SerDe utility functions for PythonMLLibAPI.
+ */
+private[spark] object SerDe extends SerDeBase with Serializable {
+
+  override val PYSPARK_PACKAGE = "pyspark.mllib"
 
-    override protected def packageName = PYSPARK_ML_PACKAGE
+  // Pickler for DenseVector
+  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val vector: NewDenseVector = obj.asInstanceOf[NewDenseVector]
+      val vector: DenseVector = obj.asInstanceOf[DenseVector]
       val bytes = new Array[Byte](8 * vector.size)
       val bb = ByteBuffer.wrap(bytes)
       bb.order(ByteOrder.nativeOrder())
@@ -1330,11 +1355,11 @@ private[spark] object SerDe extends Serializable {
       val db = bb.asDoubleBuffer()
       val ans = new Array[Double](bytes.length / 8)
       db.get(ans)
-      NewVectors.dense(ans)
+      Vectors.dense(ans)
     }
   }
 
-  // Pickler for (mllib) DenseMatrix
+  // Pickler for DenseMatrix
   private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1371,46 +1396,7 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) DenseMatrix
-  private[python] class NewDenseMatrixPickler extends BasePickler[NewDenseMatrix] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val m: NewDenseMatrix = obj.asInstanceOf[NewDenseMatrix]
-      val bytes = new Array[Byte](8 * m.values.length)
-      val order = ByteOrder.nativeOrder()
-      val isTransposed = if (m.isTransposed) 1 else 0
-      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
-
-      out.write(Opcodes.MARK)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(m.numRows))
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(m.numCols))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(bytes.length))
-      out.write(bytes)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(isTransposed))
-      out.write(Opcodes.TUPLE)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 4) {
-        throw new PickleException("should be 4")
-      }
-      val bytes = getBytes(args(2))
-      val n = bytes.length / 8
-      val values = new Array[Double](n)
-      val order = ByteOrder.nativeOrder()
-      ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
-      val isTransposed = args(3).asInstanceOf[Int] == 1
-      new NewDenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values, isTransposed)
-    }
-  }
-
-  // Pickler for (mllib) SparseMatrix
+  // Pickler for SparseMatrix
   private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1465,64 +1451,7 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) SparseMatrix
-  private[python] class NewSparseMatrixPickler extends BasePickler[NewSparseMatrix] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val s = obj.asInstanceOf[NewSparseMatrix]
-      val order = ByteOrder.nativeOrder()
-
-      val colPtrsBytes = new Array[Byte](4 * s.colPtrs.length)
-      val indicesBytes = new Array[Byte](4 * s.rowIndices.length)
-      val valuesBytes = new Array[Byte](8 * s.values.length)
-      val isTransposed = if (s.isTransposed) 1 else 0
-      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().put(s.colPtrs)
-      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().put(s.rowIndices)
-      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().put(s.values)
-
-      out.write(Opcodes.MARK)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(s.numRows))
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(s.numCols))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(colPtrsBytes.length))
-      out.write(colPtrsBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(indicesBytes.length))
-      out.write(indicesBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(valuesBytes.length))
-      out.write(valuesBytes)
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(isTransposed))
-      out.write(Opcodes.TUPLE)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 6) {
-        throw new PickleException("should be 6")
-      }
-      val order = ByteOrder.nativeOrder()
-      val colPtrsBytes = getBytes(args(2))
-      val indicesBytes = getBytes(args(3))
-      val valuesBytes = getBytes(args(4))
-      val colPtrs = new Array[Int](colPtrsBytes.length / 4)
-      val rowIndices = new Array[Int](indicesBytes.length / 4)
-      val values = new Array[Double](valuesBytes.length / 8)
-      ByteBuffer.wrap(colPtrsBytes).order(order).asIntBuffer().get(colPtrs)
-      ByteBuffer.wrap(indicesBytes).order(order).asIntBuffer().get(rowIndices)
-      ByteBuffer.wrap(valuesBytes).order(order).asDoubleBuffer().get(values)
-      val isTransposed = args(5).asInstanceOf[Int] == 1
-      new NewSparseMatrix(
-        args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], colPtrs, rowIndices, values,
-        isTransposed)
-    }
-  }
-
-  // Pickler for (mllib) SparseVector
+  // Pickler for SparseVector
   private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
 
     def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
@@ -1564,50 +1493,6 @@ private[spark] object SerDe extends Serializable {
     }
   }
 
-  // Pickler for (new) SparseVector
-  private[python] class NewSparseVectorPickler extends BasePickler[NewSparseVector] {
-
-    override protected def packageName = PYSPARK_ML_PACKAGE
-
-    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
-      val v: NewSparseVector = obj.asInstanceOf[NewSparseVector]
-      val n = v.indices.length
-      val indiceBytes = new Array[Byte](4 * n)
-      val order = ByteOrder.nativeOrder()
-      ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
-      val valueBytes = new Array[Byte](8 * n)
-      ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
-
-      out.write(Opcodes.BININT)
-      out.write(PickleUtils.integer_to_bytes(v.size))
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
-      out.write(indiceBytes)
-      out.write(Opcodes.BINSTRING)
-      out.write(PickleUtils.integer_to_bytes(valueBytes.length))
-      out.write(valueBytes)
-      out.write(Opcodes.TUPLE3)
-    }
-
-    def construct(args: Array[Object]): Object = {
-      if (args.length != 3) {
-        throw new PickleException("should be 3")
-      }
-      val size = args(0).asInstanceOf[Int]
-      val indiceBytes = getBytes(args(1))
-      val valueBytes = getBytes(args(2))
-      val n = indiceBytes.length / 4
-      val indices = new Array[Int](n)
-      val values = new Array[Double](n)
-      if (n > 0) {
-        val order = ByteOrder.nativeOrder()
-        ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
-        ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
-      }
-      new NewSparseVector(size, indices, values)
-    }
-  }
-
   // Pickler for MLlib LabeledPoint
   private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
 
@@ -1654,7 +1539,7 @@ private[spark] object SerDe extends Serializable {
   var initialized = false
   // This should be called before trying to serialize any above classes
   // In cluster mode, this should be put in the closure
-  def initialize(): Unit = {
+  override def initialize(): Unit = {
     SerDeUtil.initialize()
     synchronized {
       if (!initialized) {
@@ -1662,10 +1547,6 @@ private[spark] object SerDe extends Serializable {
         new DenseMatrixPickler().register()
         new SparseMatrixPickler().register()
         new SparseVectorPickler().register()
-        new NewDenseVectorPickler().register()
-        new NewDenseMatrixPickler().register()
-        new NewSparseMatrixPickler().register()
-        new NewSparseVectorPickler().register()
         new LabeledPointPickler().register()
         new RatingPickler().register()
         initialized = true
@@ -1674,58 +1555,4 @@ private[spark] object SerDe extends Serializable {
   }
   // will not called in Executor automatically
   initialize()
-
-  def dumps(obj: AnyRef): Array[Byte] = {
-    obj match {
-      // Pickler in Python side cannot deserialize Scala Array normally. See SPARK-12834.
-      case array: Array[_] => new Pickler().dumps(array.toSeq.asJava)
-      case _ => new Pickler().dumps(obj)
-    }
-  }
-
-  def loads(bytes: Array[Byte]): AnyRef = {
-    new Unpickler().loads(bytes)
-  }
-
-  /* convert object into Tuple */
-  def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
-    rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
-  }
-
-  /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */
-  def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
-    rdd.map(x => Array(x._1, x._2))
-  }
-
-  /**
-   * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
-   * PySpark.
-   */
-  def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
-    jRDD.rdd.mapPartitions { iter =>
-      initialize()  // let it called in executor
-      new SerDeUtil.AutoBatchedPickler(iter)
-    }
-  }
-
-  /**
-   * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
-   */
-  def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
-    pyRDD.rdd.mapPartitions { iter =>
-      initialize()  // let it called in executor
-      val unpickle = new Unpickler
-      iter.flatMap { row =>
-        val obj = unpickle.loads(row)
-        if (batched) {
-          obj match {
-            case list: JArrayList[_] => list.asScala
-            case arr: Array[_] => arr
-          }
-        } else {
-          Seq(obj)
-        }
-      }
-    }.toJavaRDD()
-  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
new file mode 100644
index 0000000000..5eaef9aabd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/python/MLSerDeSuite.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.python
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors}
+
+class MLSerDeSuite extends SparkFunSuite {
+
+  MLSerDe.initialize()
+
+  test("pickle vector") {
+    val vectors = Seq(
+      Vectors.dense(Array.empty[Double]),
+      Vectors.dense(0.0),
+      Vectors.dense(0.0, -2.0),
+      Vectors.sparse(0, Array.empty[Int], Array.empty[Double]),
+      Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
+      Vectors.sparse(2, Array(1), Array(-2.0)))
+    vectors.foreach { v =>
+      val u = MLSerDe.loads(MLSerDe.dumps(v))
+      assert(u.getClass === v.getClass)
+      assert(u === v)
+    }
+  }
+
+  test("pickle double") {
+    for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
+      val deser = MLSerDe.loads(MLSerDe.dumps(x.asInstanceOf[AnyRef])).asInstanceOf[Double]
+      // We use `equals` here for comparison because we cannot use `==` for NaN
+      assert(x.equals(deser))
+    }
+  }
+
+  test("pickle matrix") {
+    val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
+    val matrix = Matrices.dense(2, 3, values)
+    val nm = MLSerDe.loads(MLSerDe.dumps(matrix)).asInstanceOf[DenseMatrix]
+    assert(matrix === nm)
+
+    // Test conversion for empty matrix
+    val empty = Array[Double]()
+    val emptyMatrix = Matrices.dense(0, 0, empty)
+    val ne = MLSerDe.loads(MLSerDe.dumps(emptyMatrix)).asInstanceOf[DenseMatrix]
+    assert(emptyMatrix == ne)
+
+    val sm = new SparseMatrix(3, 2, Array(0, 1, 3), Array(1, 0, 2), Array(0.9, 1.2, 3.4))
+    val nsm = MLSerDe.loads(MLSerDe.dumps(sm)).asInstanceOf[SparseMatrix]
+    assert(sm.toArray === nsm.toArray)
+
+    val smt = new SparseMatrix(
+      3, 3, Array(0, 2, 3, 5), Array(0, 2, 1, 0, 2), Array(0.9, 1.2, 3.4, 5.7, 8.9),
+      isTransposed = true)
+    val nsmt = MLSerDe.loads(MLSerDe.dumps(smt)).asInstanceOf[SparseMatrix]
+    assert(smt.toArray === nsmt.toArray)
+  }
+}
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index cd4c55f79f..527ca82d31 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -116,6 +116,7 @@ def launch_gateway():
     java_import(gateway.jvm, "org.apache.spark.SparkConf")
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
+    java_import(gateway.jvm, "org.apache.spark.ml.python.*")
     java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
     # TODO(davies): move into sql
     java_import(gateway.jvm, "org.apache.spark.sql.*")
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index a7a58e17a4..339e5d6af5 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -19,7 +19,7 @@ from abc import ABCMeta, abstractmethod
 
 from pyspark import since
 from pyspark.ml.param import Params
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 77badebeb4..121b9262dd 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -26,7 +26,7 @@ from pyspark.ml.regression import (
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 92df19e804..75d9a0e8ca 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['BisectingKMeans', 'BisectingKMeansModel',
            'KMeans', 'KMeansModel',
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
new file mode 100644
index 0000000000..256e91e141
--- /dev/null
+++ b/python/pyspark/ml/common.py
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+if sys.version >= '3':
+    long = int
+    unicode = str
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
+
+# Hack for support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+    'nan': 'NaN',
+    'inf': 'Infinity',
+    '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+    if isinstance(obj, float):
+        s = str(obj)
+        return _float_str_mapping.get(s, s)
+    return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+    'SparseVector',
+    'DenseVector',
+    'DenseMatrix',
+]
+
+
+# this will call the ML version of pythonToJava()
+def _to_java_object_rdd(rdd):
+    """ Return an JavaRDD of Object by unpickling
+
+    It will convert each Python object into Java object by Pyrolite, whenever the
+    RDD is serialized in batch or not.
+    """
+    rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+    return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+    """ Convert Python object into Java """
+    if isinstance(obj, RDD):
+        obj = _to_java_object_rdd(obj)
+    elif isinstance(obj, DataFrame):
+        obj = obj._jdf
+    elif isinstance(obj, SparkContext):
+        obj = obj._jsc
+    elif isinstance(obj, list):
+        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
+    elif isinstance(obj, JavaObject):
+        pass
+    elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
+        pass
+    else:
+        data = bytearray(PickleSerializer().dumps(obj))
+        obj = sc._jvm.MLSerDe.loads(data)
+    return obj
+
+
+def _java2py(sc, r, encoding="bytes"):
+    if isinstance(r, JavaObject):
+        clsName = r.getClass().getSimpleName()
+        # convert RDD into JavaRDD
+        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+            r = r.toJavaRDD()
+            clsName = 'JavaRDD'
+
+        if clsName == 'JavaRDD':
+            jrdd = sc._jvm.MLSerDe.javaToPython(r)
+            return RDD(jrdd, sc)
+
+        if clsName == 'Dataset':
+            return DataFrame(r, SQLContext.getOrCreate(sc))
+
+        if clsName in _picklable_classes:
+            r = sc._jvm.MLSerDe.dumps(r)
+        elif isinstance(r, (JavaArray, JavaList)):
+            try:
+                r = sc._jvm.MLSerDe.dumps(r)
+            except Py4JJavaError:
+                pass  # not pickable
+
+    if isinstance(r, (bytearray, bytes)):
+        r = PickleSerializer().loads(bytes(r), encoding=encoding)
+    return r
+
+
+def callJavaFunc(sc, func, *args):
+    """ Call Java Function """
+    args = [_py2java(sc, a) for a in args]
+    return _java2py(sc, func(*args))
+
+
+def inherit_doc(cls):
+    """
+    A decorator that makes a class inherit documentation from its parents.
+    """
+    for name, func in vars(cls).items():
+        # only inherit docstring for public functions
+        if name.startswith("_"):
+            continue
+        if not func.__doc__:
+            for parent in cls.__bases__:
+                parent_func = getattr(parent, name, None)
+                if parent_func and getattr(parent_func, "__doc__", None):
+                    func.__doc__ = parent_func.__doc__
+                    break
+    return cls
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index cd071f1b7c..1fe8772da7 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.wrapper import JavaParams
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
            'MulticlassClassificationEvaluator']
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index ca77ac395d..a28764a752 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -25,7 +25,7 @@ from pyspark.ml.linalg import _convert_to_vector
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 __all__ = ['Binarizer',
            'Bucketizer',
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 0777527134..a48f4bb2ad 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
 from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 1778bfe938..0a7096794d 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -19,7 +19,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel
 from pyspark.ml.param.shared import *
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 __all__ = ['ALS', 'ALSModel']
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 7c79ab73c7..db31993f0f 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -21,7 +21,7 @@ from pyspark import since, keyword_only
 from pyspark.ml.param.shared import *
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 from pyspark.sql import DataFrame
 
 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4358175a57..981ed9dda0 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -61,7 +61,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
     GeneralizedLinearRegression
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
-from pyspark.mllib.common import _java2py
+from pyspark.ml.common import _java2py
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
@@ -1195,12 +1195,12 @@ class VectorTests(MLlibTestCase):
 
     def _test_serialize(self, v):
         self.assertEqual(v, ser.loads(ser.dumps(v)))
-        jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
-        nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+        jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
+        nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
         self.assertEqual(v, nv)
         vs = [v] * 100
-        jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
-        nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+        jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
+        nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
         self.assertEqual(vs, nvs)
 
     def test_serialize(self):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index fe87b6cdb9..f857c5e8c8 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -25,7 +25,7 @@ from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
 from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
-from pyspark.mllib.common import inherit_doc, _py2java
+from pyspark.ml.common import inherit_doc, _py2java
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
            'TrainValidationSplitModel']
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 9d28823196..4a31a29809 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -23,7 +23,7 @@ if sys.version > '3':
     unicode = str
 
 from pyspark import SparkContext, since
-from pyspark.mllib.common import inherit_doc
+from pyspark.ml.common import inherit_doc
 
 
 def _jvm():
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index fef0040faf..25c44b7533 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -22,7 +22,7 @@ from pyspark.sql import DataFrame
 from pyspark.ml import Estimator, Transformer, Model
 from pyspark.ml.param import Params
 from pyspark.ml.util import _jvm
-from pyspark.mllib.common import inherit_doc, _java2py, _py2java
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
 
 
 class JavaWrapper(object):
-- 
GitLab