diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b294ac7cdd13fc3321eb85baa07d263d4c..26505b4cc1501a15510d32cac2b3da3547982e90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
-      val clusterCenters = if (major.toInt >= 2) {
+      val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
         val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
         data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
       } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 444006fe1edb6c9e323cf7dfcac37dfcff2b38a0..1e49352b8517e043556f633637c62a5c097cbe61 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
 
 /**
  * Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
 
-      val versionRegex = "([0-9]+)\\.(.+)".r
-      val versionRegex(major, _) = metadata.sparkVersion
-
       val dataPath = new Path(path, "data").toString
-      val model = if (major.toInt >= 2) {
+      val model = if (majorVersion(metadata.sparkVersion) >= 2) {
         val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
           sparkSession.read.parquet(dataPath)
             .select("pc", "explainedVariance")