Unverified Commit de77c677 authored by VinceShieh, committed by Sean Owen

[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings

## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other ad-hoc approaches to parse Spark
version strings. Those should be fixed to use VersionUtils. This PR replaces the custom
regexes with VersionUtils when extracting Spark version numbers.
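
For illustration, the change in pattern looks roughly like this (a sketch; `"2.0.2"` is an arbitrary example version string, not one taken from the diff):

```scala
import org.apache.spark.util.VersionUtils.majorVersion

// Before: a hand-rolled regex at each call site; `major` is bound as a String
// and must be converted before it can be compared.
val versionRegex = "([0-9]+)\\.(.+)".r
val versionRegex(major, _) = "2.0.2"
val usesNewLayout = major.toInt >= 2

// After: the shared utility returns the major version directly as an Int.
val usesNewLayoutViaUtil = majorVersion("2.0.2") >= 2
```
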
## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vincent.xie@intel.com>

Closes #15055 from VinceShieh/SPARK-17462.
parent 49b6f456
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
......
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")
......
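
For context, a minimal sketch of how the shared utility behaves, assuming the `org.apache.spark.util.VersionUtils` API imported in the hunks above (the failure case shown is illustrative):

```scala
import org.apache.spark.util.VersionUtils

// Parses "major.minor[.anything]" version strings; unlike the lax regex it
// replaces, it fails fast on metadata it cannot parse.
val major: Int = VersionUtils.majorVersion("2.0.2")       // 2
val (maj, min) = VersionUtils.majorMinorVersion("2.0.2")  // (2, 0)
// VersionUtils.majorVersion("not-a-version") would throw IllegalArgumentException.
```
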