From 82870d507dfaeeaf315d6766ca1496205c6216d3 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Mon, 8 Jun 2015 21:33:47 -0700
Subject: [PATCH] [SPARK-8168] [MLLIB] Add Python friendly constructor to
 PipelineModel

This makes the constructor callable in Python. dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #6709 from mengxr/SPARK-8168 and squashes the following commits:

f871de4 [Xiangrui Meng] Add Python friendly constructor to PipelineModel
---
 .../scala/org/apache/spark/ml/Pipeline.scala    |  8 ++++++++
 .../org/apache/spark/ml/PipelineSuite.scala     | 17 +++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 11a4722722..a9bd28df71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.ml
 
+import java.{util => ju}
+
+import scala.collection.JavaConverters._
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
@@ -175,6 +178,11 @@ class PipelineModel private[ml] (
     val stages: Array[Transformer])
   extends Model[PipelineModel] with Logging {
 
+  /** A Java/Python-friendly auxiliary constructor. */
+  private[ml] def this(uid: String, stages: ju.List[Transformer]) = {
+    this(uid, stages.asScala.toArray)
+  }
+
   override def validateParams(): Unit = {
     super.validateParams()
     stages.foreach(_.validateParams())
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 05bf58e63a..29394fefcb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml
 
+import scala.collection.JavaConverters._
+
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito.when
 import org.scalatest.mock.MockitoSugar.mock
@@ -81,4 +83,19 @@ class PipelineSuite extends SparkFunSuite {
       pipeline.fit(dataset)
     }
   }
+
+  test("pipeline model constructors") {
+    val transform0 = mock[Transformer]
+    val model1 = mock[MyModel]
+
+    val stages = Array(transform0, model1)
+    val pipelineModel0 = new PipelineModel("pipeline0", stages)
+    assert(pipelineModel0.uid === "pipeline0")
+    assert(pipelineModel0.stages === stages)
+
+    val stagesAsList = stages.toList.asJava
+    val pipelineModel1 = new PipelineModel("pipeline1", stagesAsList)
+    assert(pipelineModel1.uid === "pipeline1")
+    assert(pipelineModel1.stages === stages)
+  }
 }
-- 
GitLab