Skip to content
Snippets Groups Projects
Commit 14bc5a7f authored by Yuhao Yang's avatar Yuhao Yang Committed by Xiangrui Meng
Browse files

[SPARK-16177][ML] model loading backward compatibility for ml.regression

## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-16177
model loading backward compatibility for ml.regression

## How was this patch tested?

existing ut and manual test for loading 1.6 models.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #13879 from hhbyyh/regreComp.
parent 6a3c6276
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
......@@ -389,10 +390,10 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
.select("coefficients", "intercept", "scale").head()
val coefficients = data.getAs[Vector](0)
val intercept = data.getDouble(1)
val scale = data.getDouble(2)
val Row(coefficients: Vector, intercept: Double, scale: Double) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("coefficients", "intercept", "scale")
.head()
val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
DefaultParamsReader.getAndSetParams(model, metadata)
......
......@@ -39,6 +39,7 @@ import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
......@@ -500,9 +501,10 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.format("parquet").load(dataPath)
.select("intercept", "coefficients").head()
val intercept = data.getDouble(0)
val coefficients = data.getAs[Vector](1)
val Row(intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("intercept", "coefficients")
.head()
val model = new LinearRegressionModel(metadata.uid, coefficients, intercept)
DefaultParamsReader.getAndSetParams(model, metadata)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment