Skip to content
Snippets Groups Projects
Commit 3bd31294 authored by Sandeep's avatar Sandeep Committed by Matei Zaharia
Browse files

SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

Author: Sandeep <sandeep@techaddict.me>

Closes #356 from techaddict/1428 and squashes the following commits:

3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
parent 79820fe8
No related branches found
No related tags found
No related merge requests found
......@@ -15,8 +15,9 @@
# limitations under the License.
#
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD
import numpy as np
from pyspark.serializers import Serializer
import struct
......@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
return ar.copy()
def _serialize_double_vector(v):
"""Serialize a double vector into a mutually understood format."""
"""Serialize a double vector into a mutually understood format.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
>>> array_equal(y, array([1.0, 2.0, 3.0]))
True
"""
if type(v) != ndarray:
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
"""complex is only datatype that can't be converted to float64"""
if issubdtype(v.dtype, complex):
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
if v.dtype != float64:
raise TypeError("_serialize_double_vector called on an ndarray of %s; "
"wanted ndarray of float64" % v.dtype)
v = v.astype(float64)
if v.ndim != 1:
raise TypeError("_serialize_double_vector called on a %ddarray; "
"wanted a 1darray" % v.ndim)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment