Skip to content
Snippets Groups Projects
Commit b30a2dc7 authored by zlpmichelle's avatar zlpmichelle Committed by Yanbo Liang
Browse files

[SPARK-16241][ML] model loading backward compatibility for ml NaiveBayes

## What changes were proposed in this pull request?

model loading backward compatibility for ml NaiveBayes

## How was this patch tested?

existing ut and manual test for loading models saved by Spark 1.6.

Author: zlpmichelle <zlpmichelle@gmail.com>

Closes #13940 from zlpmichelle/naivebayes.
parent 2c3d9613
No related branches found
No related tags found
No related merge requests found
...@@ -28,8 +28,9 @@ import org.apache.spark.ml.util._ ...@@ -28,8 +28,9 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes} import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel} import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset import org.apache.spark.sql.{Dataset, Row}
/** /**
* Params for Naive Bayes Classifiers. * Params for Naive Bayes Classifiers.
...@@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] { ...@@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head() val data = sparkSession.read.parquet(dataPath)
val pi = data.getAs[Vector](0) val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
val theta = data.getAs[Matrix](1) val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
.select("pi", "theta")
.head()
val model = new NaiveBayesModel(metadata.uid, pi, theta) val model = new NaiveBayesModel(metadata.uid, pi, theta)
DefaultParamsReader.getAndSetParams(model, metadata) DefaultParamsReader.getAndSetParams(model, metadata)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment