Commit cc6778ee authored by Yuhao Yang, committed by Xiangrui Meng

[SPARK-16133][ML] model loading backward compatibility for ml.feature

## What changes were proposed in this pull request?

Adds model loading backward compatibility for ml.feature. The loaders for IDFModel, MinMaxScalerModel, and StandardScalerModel now run the persisted data through MLUtils.convertVectorColumnsToML before extracting the vectors, so vector columns written in the old mllib.linalg format by Spark 1.6 deserialize correctly as ml.linalg vectors.
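The common thread across the three loaders in the diff below is that conversion call. A minimal sketch of it in isolation (not from the patch; it assumes a live SparkSession in scope as `spark`, and the column value is illustrative):

```scala
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.util.MLUtils

// A one-row frame with an mllib-style vector column, the format a
// Spark 1.6 model file would contain on disk.
val data = spark.createDataFrame(Seq(Tuple1(OldVectors.dense(1.0, 2.0)))).toDF("idf")

// Rewrites the named columns to the new ml.linalg vector type; columns already
// in the new format pass through unchanged, so 2.0-saved models load through
// the same code path.
val converted = MLUtils.convertVectorColumnsToML(data, "idf")
```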

## How was this patch tested?

Existing unit tests, plus a manual test that loads models saved by Spark 1.6.
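A sketch of the kind of manual check described; the path is hypothetical, standing in for a model directory written by `model.save(...)` under Spark 1.6:

```scala
import org.apache.spark.ml.feature.IDFModel

// Before this patch, loading a 1.6-saved model failed because the persisted
// "idf" column held an old mllib.linalg.Vector; now it is converted on read.
val model = IDFModel.load("/tmp/spark-1.6-models/idf")
println(model.idf)  // an org.apache.spark.ml.linalg.Vector
```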

Author: Yuhao Yang <yuhao.yang@intel.com>
Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #13844 from hhbyyh/featureComp.
parent 4a40d43b
mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala

@@ -27,6 +27,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._

@@ -180,9 +181,9 @@ object IDFModel extends MLReadable[IDFModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
+      val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
         .select("idf")
         .head()
-      val idf = data.getAs[Vector](0)
       val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
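Why the detour through `MLUtils.convertVectorColumnsToML`: for a 1.6 file the `idf` column deserializes as an old `mllib.linalg.Vector`, so extracting it as the new `ml.linalg.Vector` (as both the removed `getAs[Vector]` and the new `Row(idf: Vector)` pattern do) would fail at runtime with a `MatchError` or a downstream `ClassCastException`. A hedged sketch of the successful path, reusing the `converted` frame from the earlier example:

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.Row

// Succeeds because the column was converted to the new ml vector type first;
// without the conversion, this pattern match would fail on 1.6 data.
val Row(idf: Vector) = converted.select("idf").head()
```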
mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala

@@ -28,6 +28,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._

@@ -232,9 +233,11 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
     override def load(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(originalMin: Vector, originalMax: Vector) = sparkSession.read.parquet(dataPath)
-        .select("originalMin", "originalMax")
-        .head()
+      val data = sparkSession.read.parquet(dataPath)
+      val Row(originalMin: Vector, originalMax: Vector) =
+        MLUtils.convertVectorColumnsToML(data, "originalMin", "originalMax")
+          .select("originalMin", "originalMax")
+          .head()
       val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
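This hunk applies the same fix to two vector columns at once: `convertVectorColumnsToML` takes the column names as varargs, so one call covers both persisted vectors (the StandardScaler hunk below does the same with "std" and "mean"):

```scala
// Variadic form, with the column names taken from the hunk above.
val converted = MLUtils.convertVectorColumnsToML(data, "originalMin", "originalMax")
```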
mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

@@ -28,6 +28,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._

@@ -211,7 +212,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector) = sparkSession.read.parquet(dataPath)
+      val data = sparkSession.read.parquet(dataPath)
+      val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean")
         .select("std", "mean")
         .head()
       val model = new StandardScalerModel(metadata.uid, std, mean)