From 454a00df2a43176cb774cad7277934a775618db1 Mon Sep 17 00:00:00 2001
From: Xusen Yin <yinxusen@gmail.com>
Date: Sun, 20 Mar 2016 15:34:34 -0700
Subject: [PATCH] [SPARK-13993][PYSPARK] Add pyspark Rformula/RforumlaModel
 save/load

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13993

## How was this patch tested?

doctest

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11807 from yinxusen/SPARK-13993.
---
 python/pyspark/ml/feature.py | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 5025493c42..3182faac0d 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2360,7 +2360,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable):
 
 
 @inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
@@ -2376,7 +2376,8 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
     ...     (0.0, 0.0, "a")
     ... ], ["y", "x", "s"])
     >>> rf = RFormula(formula="y ~ x + s")
-    >>> rf.fit(df).transform(df).show()
+    >>> model = rf.fit(df)
+    >>> model.transform(df).show()
     +---+---+---+---------+-----+
     |  y|  x|  s| features|label|
     +---+---+---+---------+-----+
@@ -2394,6 +2395,29 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
     |0.0|0.0|  a|   [0.0]|  0.0|
     +---+---+---+--------+-----+
     ...
+    >>> rFormulaPath = temp_path + "/rFormula"
+    >>> rf.save(rFormulaPath)
+    >>> loadedRF = RFormula.load(rFormulaPath)
+    >>> loadedRF.getFormula() == rf.getFormula()
+    True
+    >>> loadedRF.getFeaturesCol() == rf.getFeaturesCol()
+    True
+    >>> loadedRF.getLabelCol() == rf.getLabelCol()
+    True
+    >>> modelPath = temp_path + "/rFormulaModel"
+    >>> model.save(modelPath)
+    >>> loadedModel = RFormulaModel.load(modelPath)
+    >>> loadedModel.uid == model.uid
+    True
+    >>> loadedModel.transform(df).show()
+    +---+---+---+---------+-----+
+    |  y|  x|  s| features|label|
+    +---+---+---+---------+-----+
+    |1.0|1.0|  a|[1.0,1.0]|  1.0|
+    |0.0|2.0|  b|[2.0,0.0]|  0.0|
+    |0.0|0.0|  a|[0.0,1.0]|  0.0|
+    +---+---+---+---------+-----+
+    ...
 
     .. versionadded:: 1.5.0
     """
@@ -2439,7 +2463,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol):
         return RFormulaModel(java_model)
 
 
-class RFormulaModel(JavaModel):
+class RFormulaModel(JavaModel, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
-- 
GitLab