diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149699a277a37c299090ddbd236bbb007a21..df2a9c0dd5094f448f0efb5e1d5388166adf9b3b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,12 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
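+  // Cache the centers paired with their L2 norms, so that predict() and
+  // computeCost() do not recompute the norms on every call. The null guard
+  // keeps construction safe when the centers array is null.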
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +55,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +70,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
@@ -82,13 +87,9 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
-
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +128,8 @@
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c438d93a1e5e5e304e9cb1213beb964d05b..3ca75e8cdb97a328fab8dd72fb115988e5ea8702 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,6 +145,6 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
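
The StreamingKMeans change is a consequence of the caching above. `StreamingKMeansModel` extends `KMeansModel`, and `update` mutates `clusterCenters` and `clusterWeights` in place; returning `this` would hand back a model whose inherited, eagerly computed `clusterCentersWithNorm` still reflects the pre-update centers. Constructing a fresh model re-runs the superclass initializer and rebuilds the cache. A self-contained sketch of the stale-cache hazard (illustrative names, not the Spark API):

```scala
object StaleCacheDemo {
  // Stands in for KMeansModel with its eager clusterCentersWithNorm val.
  class Cached(val xs: Array[Double]) {
    val cachedSum: Double = xs.sum // derived state, computed once at construction
  }

  def main(args: Array[String]): Unit = {
    val xs = Array(1.0, 2.0, 3.0)
    val m = new Cached(xs)
    xs(0) = 100.0                    // in-place mutation, like update()'s BLAS calls
    assert(m.cachedSum == 6.0)       // stale: still reflects the old contents
    val fresh = new Cached(xs)       // what update() now does: build a new model
    assert(fresh.cachedSum == 105.0) // derived state recomputed from current data
  }
}
```

Note that the fresh model still aliases the same mutable arrays, so only the most recently returned model has a cache consistent with the centers; that matches how `trainOn` consumes it, rebinding `model = model.update(rdd, decayFactor, timeUnit)` on each batch.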