From 3bd312940e2f5250edaf3e88d6c23de25bb1d0a9 Mon Sep 17 00:00:00 2001
From: Sandeep <sandeep@techaddict.me>
Date: Thu, 10 Apr 2014 11:17:41 -0700
Subject: [PATCH] SPARK-1428: MLlib should convert non-float64 NumPy arrays to
 float64 instead of complaining

Author: Sandeep <sandeep@techaddict.me>

Closes #356 from techaddict/1428 and squashes the following commits:

3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
---
 python/pyspark/mllib/_common.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 20a0e309d1..7ef251d24c 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,8 +15,9 @@
 # limitations under the License.
 #
 
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
 from pyspark import SparkContext, RDD
+import numpy as np
 
 from pyspark.serializers import Serializer
 import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
     return ar.copy()
 
 def _serialize_double_vector(v):
-    """Serialize a double vector into a mutually understood format."""
+    """Serialize a double vector into a mutually understood format.
+
+    >>> x = array([1,2,3])
+    >>> y = _deserialize_double_vector(_serialize_double_vector(x))
+    >>> array_equal(y, array([1.0, 2.0, 3.0]))
+    True
+    """
     if type(v) != ndarray:
         raise TypeError("_serialize_double_vector called on a %s; "
                 "wanted ndarray" % type(v))
+    """complex is only datatype that can't be converted to float64"""
+    if issubdtype(v.dtype, complex):
+        raise TypeError("_serialize_double_vector called on a %s; "
+                "wanted ndarray" % type(v))
     if v.dtype != float64:
-        raise TypeError("_serialize_double_vector called on an ndarray of %s; "
-                "wanted ndarray of float64" % v.dtype)
+        v = v.astype(float64)
     if v.ndim != 1:
         raise TypeError("_serialize_double_vector called on a %ddarray; "
                 "wanted a 1darray" % v.ndim)
-- 
GitLab