Unverified Commit 42777b1b authored by VinceShieh, committed by Sean Owen

[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings


## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other ad-hoc approaches to parse Spark version strings; these should all go through VersionUtils. This PR replaces the custom regexes with calls to VersionUtils to get Spark version numbers.
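
For illustration, a minimal sketch of the substitution (the literal version string "2.1.0" is only an example input, not taken from the PR):

```scala
import org.apache.spark.util.VersionUtils.majorVersion

// Before: each model reader rolled its own regex and extractor pattern.
val versionRegex = "([0-9]+)\\.(.+)".r
val versionRegex(major, _) = "2.1.0"  // binds major to the String "2"
val loadedFromV2 = major.toInt >= 2

// After: one shared utility returns the major version as an Int.
val loadedFromV2Now = majorVersion("2.1.0") >= 2
```
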
## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vincent.xie@intel.com>

Closes #15055 from VinceShieh/SPARK-17462.

(cherry picked from commit de77c677)
Signed-off-by: Sean Owen <sowen@cloudera.com>
parent 4fcecb4c
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
...
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")
...
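
One practical difference, based on a reading of VersionUtils at the time rather than anything stated in the PR itself: the old extractor pattern fails with a bare scala.MatchError on a malformed version string, while VersionUtils raises an IllegalArgumentException with a descriptive message. Expected behavior, assuming that implementation:

```scala
import org.apache.spark.util.VersionUtils.{majorVersion, majorMinorVersion}

majorVersion("2.0.1")          // 2
majorMinorVersion("2.1.0")     // (2, 1)
majorVersion("not-a-version")  // throws IllegalArgumentException describing the parse failure
```
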