Skip to content
Snippets Groups Projects
Commit 8e491af5 authored by Zheng RuiFeng's avatar Zheng RuiFeng Committed by Yanbo Liang
Browse files

[SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain...

[SPARK-14077][ML][FOLLOW-UP] Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0

## What changes were proposed in this pull request?
Revert change for NB Model's Load to maintain compatibility with the model stored before 2.0

## How was this patch tested?
local build

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #15313 from zhengruifeng/revert_save_load.
parent 1fad5596
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,8 @@ import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
......@@ -362,9 +363,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
val pi = data.getAs[Vector](0)
val theta = data.getAs[Matrix](1)
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
.select("pi", "theta")
.head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)
DefaultParamsReader.getAndSetParams(model, metadata)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment