From c17b1abff8f8c6d24cb0cf4ff4f8c14a780c64b0 Mon Sep 17 00:00:00 2001
From: Yuhao Yang <yuhao.yang@intel.com>
Date: Mon, 27 Jun 2016 12:27:39 -0700
Subject: [PATCH] [SPARK-16187][ML] Implement util method for ML Matrix
 conversion in scala/java

## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-16187
This is to provide conversion utils between old/new vector columns in a DataFrame. So users can use it to migrate their datasets and pipelines manually.

## How was this patch tested?

java and scala ut

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #13888 from hhbyyh/matComp.
---
 .../apache/spark/ml/linalg/MatrixUDT.scala    |   2 +-
 .../org/apache/spark/mllib/util/MLUtils.scala | 107 +++++++++++++++++-
 .../spark/mllib/util/JavaMLUtilsSuite.java    |  29 ++++-
 .../spark/mllib/util/MLUtilsSuite.scala       |  56 ++++++++-
 4 files changed, 187 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
index 521a216c67..a1e53662f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
  * User-defined type for [[Matrix]] in [[mllib-local]] which allows easy interaction with SQL
  * via [[org.apache.spark.sql.Dataset]].
  */
-private[ml] class MatrixUDT extends UserDefinedType[Matrix] {
+private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
 
   override def sqlType: StructType = {
     // type: 0 = sparse, 1 = dense
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 7d5bdffc42..e96c2bc6ed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.linalg.{VectorUDT => MLVectorUDT}
+import org.apache.spark.ml.linalg.{MatrixUDT => MLMatrixUDT, VectorUDT => MLVectorUDT}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.regression.LabeledPoint
@@ -309,8 +309,8 @@ object MLUtils extends Logging {
   }
 
   /**
-   * Converts vector columns in an input Dataset to the [[org.apache.spark.ml.linalg.Vector]] type
-   * from the new [[org.apache.spark.mllib.linalg.Vector]] type under the `spark.ml` package.
+   * Converts vector columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Vector]]
+   * type from the new [[org.apache.spark.ml.linalg.Vector]] type under the `spark.ml` package.
    * @param dataset input dataset
    * @param cols a list of vector columns to be converted. Old vector columns will be ignored. If
    *             unspecified, all new vector columns will be converted except nested ones.
@@ -360,6 +360,107 @@ object MLUtils extends Logging {
     dataset.select(exprs: _*)
   }
 
+  /**
+   * Converts Matrix columns in an input Dataset from the [[org.apache.spark.mllib.linalg.Matrix]]
+   * type to the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package.
+   * @param dataset input dataset
+   * @param cols a list of matrix columns to be converted. New matrix columns will be ignored. If
+   *             unspecified, all old matrix columns will be converted except nested ones.
+   * @return the input [[DataFrame]] with old matrix columns converted to the new matrix type
+   */
+  @Since("2.0.0")
+  @varargs
+  def convertMatrixColumnsToML(dataset: Dataset[_], cols: String*): DataFrame = {
+    val schema = dataset.schema
+    val colSet = if (cols.nonEmpty) {
+      cols.flatMap { c =>
+        val dataType = schema(c).dataType
+        if (dataType.getClass == classOf[MatrixUDT]) {
+          Some(c)
+        } else {
+          // ignore new matrix columns and raise an exception on other column types
+          require(dataType.getClass == classOf[MLMatrixUDT],
+            s"Column $c must be old Matrix type to be converted to new type but got $dataType.")
+          None
+        }
+      }.toSet
+    } else {
+      schema.fields
+        .filter(_.dataType.getClass == classOf[MatrixUDT])
+        .map(_.name)
+        .toSet
+    }
+
+    if (colSet.isEmpty) {
+      return dataset.toDF()
+    }
+
+    logWarning("Matrix column conversion has serialization overhead. " +
+      "Please migrate your datasets and workflows to use the spark.ml package.")
+
+    val convertToML = udf { v: Matrix => v.asML }
+    val exprs = schema.fields.map { field =>
+      val c = field.name
+      if (colSet.contains(c)) {
+        convertToML(col(c)).as(c, field.metadata)
+      } else {
+        col(c)
+      }
+    }
+    dataset.select(exprs: _*)
+  }
+
+  /**
+   * Converts matrix columns in an input Dataset to the [[org.apache.spark.mllib.linalg.Matrix]]
+   * type from the new [[org.apache.spark.ml.linalg.Matrix]] type under the `spark.ml` package.
+   * @param dataset input dataset
+   * @param cols a list of matrix columns to be converted. Old matrix columns will be ignored. If
+   *             unspecified, all new matrix columns will be converted except nested ones.
+   * @return the input [[DataFrame]] with new matrix columns converted to the old matrix type
+   */
+  @Since("2.0.0")
+  @varargs
+  def convertMatrixColumnsFromML(dataset: Dataset[_], cols: String*): DataFrame = {
+    val schema = dataset.schema
+    val colSet = if (cols.nonEmpty) {
+      cols.flatMap { c =>
+        val dataType = schema(c).dataType
+        if (dataType.getClass == classOf[MLMatrixUDT]) {
+          Some(c)
+        } else {
+          // ignore old matrix columns and raise an exception on other column types
+          require(dataType.getClass == classOf[MatrixUDT],
+            s"Column $c must be new Matrix type to be converted to old type but got $dataType.")
+          None
+        }
+      }.toSet
+    } else {
+      schema.fields
+        .filter(_.dataType.getClass == classOf[MLMatrixUDT])
+        .map(_.name)
+        .toSet
+    }
+
+    if (colSet.isEmpty) {
+      return dataset.toDF()
+    }
+
+    logWarning("Matrix column conversion has serialization overhead. " +
+      "Please migrate your datasets and workflows to use the spark.ml package.")
+
+    val convertFromML = udf { Matrices.fromML _ }
+    val exprs = schema.fields.map { field =>
+      val c = field.name
+      if (colSet.contains(c)) {
+        convertFromML(col(c)).as(c, field.metadata)
+      } else {
+        col(c)
+      }
+    }
+    dataset.select(exprs: _*)
+  }
+
+
   /**
    * Returns the squared Euclidean distance between two vectors. The following formula will be used
    * if it does not introduce too much numerical error:
diff --git a/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
index 2fa0bd2546..e271a0a77c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
@@ -17,18 +17,22 @@
 
 package org.apache.spark.mllib.util;
 
+import java.util.Arrays;
 import java.util.Collections;
 
 import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.spark.SharedSparkSession;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.linalg.*;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 
 public class JavaMLUtilsSuite extends SharedSparkSession {
 
@@ -46,4 +50,25 @@ public class JavaMLUtilsSuite extends SharedSparkSession {
     Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
     Assert.assertEquals(RowFactory.create(1.0, x), old1);
   }
+
+  @Test
+  public void testConvertMatrixColumnsToAndFromML() {
+    Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
+    StructType schema = new StructType(new StructField[]{
+      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+      new StructField("features", new MatrixUDT(), false, Metadata.empty())
+    });
+    Dataset<Row> dataset = spark.createDataFrame(
+      Arrays.asList(
+        RowFactory.create(1.0, x)),
+      schema);
+
+    Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
+    Row new1 = newDataset1.first();
+    Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);
+    Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
+    Assert.assertEquals(new1, new2);
+    Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
+    Assert.assertEquals(RowFactory.create(1.0, x), old1);
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 3801bd127a..6aa93c9076 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -26,7 +26,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
 import com.google.common.io.Files
 
 import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
 import org.apache.spark.mllib.util.TestingUtils._
@@ -301,4 +301,58 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
       convertVectorColumnsFromML(df, "p._2")
     }
   }
+
+  test("convertMatrixColumnsToML") {
+    val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
+    val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
+    val y = Matrices.dense(2, 1, Array(0.2, 1.3))
+    val z = Matrices.ones(1, 1)
+    val p = (5.0, z)
+    val w = Matrices.dense(1, 1, Array(4.5)).asML
+    val df = spark.createDataFrame(Seq(
+      (0, x, y, p, w)
+    )).toDF("id", "x", "y", "p", "w")
+      .withColumn("x", col("x"), metadata)
+    val newDF1 = convertMatrixColumnsToML(df)
+    assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.")
+    val new1 = newDF1.first()
+    assert(new1 === Row(0, x.asML, y.asML, Row(5.0, z), w))
+    val new2 = convertMatrixColumnsToML(df, "x", "y").first()
+    assert(new2 === new1)
+    val new3 = convertMatrixColumnsToML(df, "y", "w").first()
+    assert(new3 === Row(0, x, y.asML, Row(5.0, z), w))
+    intercept[IllegalArgumentException] {
+      convertMatrixColumnsToML(df, "p")
+    }
+    intercept[IllegalArgumentException] {
+      convertMatrixColumnsToML(df, "p._2")
+    }
+  }
+
+  test("convertMatrixColumnsFromML") {
+    val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)).asML
+    val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
+    val y = Matrices.dense(2, 1, Array(0.2, 1.3)).asML
+    val z = Matrices.ones(1, 1).asML
+    val p = (5.0, z)
+    val w = Matrices.dense(1, 1, Array(4.5))
+    val df = spark.createDataFrame(Seq(
+      (0, x, y, p, w)
+    )).toDF("id", "x", "y", "p", "w")
+      .withColumn("x", col("x"), metadata)
+    val newDF1 = convertMatrixColumnsFromML(df)
+    assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.")
+    val new1 = newDF1.first()
+    assert(new1 === Row(0, Matrices.fromML(x), Matrices.fromML(y), Row(5.0, z), w))
+    val new2 = convertMatrixColumnsFromML(df, "x", "y").first()
+    assert(new2 === new1)
+    val new3 = convertMatrixColumnsFromML(df, "y", "w").first()
+    assert(new3 === Row(0, x, Matrices.fromML(y), Row(5.0, z), w))
+    intercept[IllegalArgumentException] {
+      convertMatrixColumnsFromML(df, "p")
+    }
+    intercept[IllegalArgumentException] {
+      convertMatrixColumnsFromML(df, "p._2")
+    }
+  }
 }
-- 
GitLab