Skip to content
Snippets Groups Projects
Commit 3bd31294 authored by Sandeep's avatar Sandeep Committed by Matei Zaharia
Browse files

SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining

Author: Sandeep <sandeep@techaddict.me>

Closes #356 from techaddict/1428 and squashes the following commits:

3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
parent 79820fe8
No related branches found
No related tags found
No related merge requests found
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
# limitations under the License. # limitations under the License.
# #
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD from pyspark import SparkContext, RDD
import numpy as np
from pyspark.serializers import Serializer from pyspark.serializers import Serializer
import struct import struct
...@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset): ...@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
return ar.copy() return ar.copy()
def _serialize_double_vector(v): def _serialize_double_vector(v):
"""Serialize a double vector into a mutually understood format.""" """Serialize a double vector into a mutually understood format.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
>>> array_equal(y, array([1.0, 2.0, 3.0]))
True
"""
if type(v) != ndarray: if type(v) != ndarray:
raise TypeError("_serialize_double_vector called on a %s; " raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v)) "wanted ndarray" % type(v))
"""complex is only datatype that can't be converted to float64"""
if issubdtype(v.dtype, complex):
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
if v.dtype != float64: if v.dtype != float64:
raise TypeError("_serialize_double_vector called on an ndarray of %s; " v = v.astype(float64)
"wanted ndarray of float64" % v.dtype)
if v.ndim != 1: if v.ndim != 1:
raise TypeError("_serialize_double_vector called on a %ddarray; " raise TypeError("_serialize_double_vector called on a %ddarray; "
"wanted a 1darray" % v.ndim) "wanted a 1darray" % v.ndim)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment