From 6765ef98dff070768bbcd585d341ee7664fbe76c Mon Sep 17 00:00:00 2001
From: MechCoder <manojkumarsivaraj334@gmail.com>
Date: Wed, 17 Jun 2015 11:10:16 -0700
Subject: [PATCH] [SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark

MatrixUDT was recently implemented in Scala. This patch ports it to PySpark.
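
With this change, a matrix column round-trips between Python and the JVM.
A minimal usage sketch (not part of this patch), assuming an active
SparkContext named sc and an existing SQLContext so that RDD.toDF() is
available:

    from pyspark.mllib.linalg import DenseMatrix, SparseMatrix

    # Assumed setup: sc is a live SparkContext and a SQLContext exists.
    rdd = sc.parallelize([("dense", DenseMatrix(2, 2, [1.0, 2.0, 3.0, 4.0])),
                          ("sparse", SparseMatrix(1, 1, [0, 1], [0], [2.0]))])
    df = rdd.toDF()
    m = df.first()._2  # deserialized back into a pyspark.mllib.linalg matrix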

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6354 from MechCoder/spark-6390 and squashes the following commits:

fc4dc1e [MechCoder] Better error message
c940a44 [MechCoder] Added test
aa9c391 [MechCoder] Add pyUDT to MatrixUDT
62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark
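
On the SQL side, both matrix kinds share a single seven-field struct
(type, numRows, numCols, colPtrs, rowIndices, values, isTransposed),
mirroring the Scala MatrixUDT; the leading type byte is 0 for sparse
and 1 for dense, and the unused CSC fields are null for dense matrices.
For illustration, the serialized form of a 1x1 sparse matrix:

    >>> from pyspark.mllib.linalg import MatrixUDT, SparseMatrix
    >>> MatrixUDT().serialize(SparseMatrix(1, 1, [0, 1], [0], [2.0]))
    (0, 1, 1, [0, 1], [0], [2.0], False)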
---
 .../apache/spark/mllib/linalg/Matrices.scala  |  2 +
 python/pyspark/mllib/linalg.py                | 59 ++++++++++++++++++-
 python/pyspark/mllib/tests.py                 | 34 ++++++++++-
 python/pyspark/sql/dataframe.py               |  6 +-
 4 files changed, 97 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 9584da8e3a..85e63b1382 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -197,6 +197,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
   override def typeName: String = "matrix"
 
+  override def pyUDT: String = "pyspark.mllib.linalg.MatrixUDT"
+
   private[spark] override def asNullable: MatrixUDT = this
 }
 
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 23d1a79ffe..e96c5ef87d 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -36,7 +36,7 @@ else:
 import numpy as np
 
 from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
-    IntegerType, ByteType
+    IntegerType, ByteType, BooleanType
 
 
 __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
@@ -163,6 +163,59 @@ class VectorUDT(UserDefinedType):
         return "vector"
 
 
+class MatrixUDT(UserDefinedType):
+    """
+    SQL user-defined type (UDT) for Matrix.
+    """
+
+    @classmethod
+    def sqlType(cls):
+        return StructType([
+            StructField("type", ByteType(), False),
+            StructField("numRows", IntegerType(), False),
+            StructField("numCols", IntegerType(), False),
+            StructField("colPtrs", ArrayType(IntegerType(), False), True),
+            StructField("rowIndices", ArrayType(IntegerType(), False), True),
+            StructField("values", ArrayType(DoubleType(), False), True),
+            StructField("isTransposed", BooleanType(), False)])
+
+    @classmethod
+    def module(cls):
+        return "pyspark.mllib.linalg"
+
+    @classmethod
+    def scalaUDT(cls):
+        return "org.apache.spark.mllib.linalg.MatrixUDT"
+
+    def serialize(self, obj):
+        if isinstance(obj, SparseMatrix):
+            colPtrs = [int(i) for i in obj.colPtrs]
+            rowIndices = [int(i) for i in obj.rowIndices]
+            values = [float(v) for v in obj.values]
+            return (0, obj.numRows, obj.numCols, colPtrs,
+                    rowIndices, values, bool(obj.isTransposed))
+        elif isinstance(obj, DenseMatrix):
+            values = [float(v) for v in obj.values]
+            return (1, obj.numRows, obj.numCols, None, None, values,
+                    bool(obj.isTransposed))
+        else:
+            raise TypeError("cannot serialize type %r" % type(obj))
+
+    def deserialize(self, datum):
+        assert len(datum) == 7, \
+            "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+        tpe = datum[0]
+        if tpe == 0:
+            return SparseMatrix(*datum[1:])
+        elif tpe == 1:
+            return DenseMatrix(datum[1], datum[2], datum[5], datum[6])
+        else:
+            raise ValueError("do not recognize type %r" % tpe)
+
+    def simpleString(self):
+        return "matrix"
+
+
 class Vector(object):
 
     __UDT__ = VectorUDT()
@@ -781,10 +834,12 @@ class Vectors(object):
 
 
 class Matrix(object):
     """
     Represents a local matrix.
     """
 
+    __UDT__ = MatrixUDT()
+
     def __init__(self, numRows, numCols, isTransposed=False):
         self.numRows = numRows
         self.numCols = numCols
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
 from pyspark import SparkContext
 from pyspark.mllib.common import _to_java_object_rdd
 from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
-    DenseMatrix, SparseMatrix, Vectors, Matrices
+    DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
                 raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
 
 
+class MatrixUDTTests(MLlibTestCase):
+
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        self.assertEqual(schema.fields[1].dataType, self.udt)
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertEqual(m, self.dm1)
+            elif isinstance(m, SparseMatrix):
+                self.assertEqual(m, self.sm1)
+            else:
+                raise TypeError("expected a matrix but got %r of type %r" % (m, type(m)))
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(MLlibTestCase):
 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9615e57649..152b87351d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -194,7 +194,11 @@ class DataFrame(object):
         StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
         """
         if self._schema is None:
-            self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+            try:
+                self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+            except AttributeError as e:
+                raise Exception(
+                    "Unable to parse datatype from schema. %s" % e)
         return self._schema
 
     @since(1.3)
-- 
GitLab