diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c9e6f1dec6bf85f40d77f54aa6d028984deea102..48daa87e82d13568509f38c505580dfd86d28129 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -346,7 +346,7 @@ class GaussianMixture(object):
             if initialModel.k != k:
                 raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
                                 % (initialModel.k, k))
-            initialModelWeights = initialModel.weights
+            initialModelWeights = list(initialModel.weights)
             initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
             initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
         java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 6ed03e35828edbaa84cdb2f65aad3dc25d539faa..3436a28b2974ff8f6bfaccc100d1f0e27976b6b1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -475,6 +475,18 @@ class ListTests(MLlibTestCase):
         for c1, c2 in zip(clusters1.weights, clusters2.weights):
             self.assertEqual(round(c1, 7), round(c2, 7))
 
+    def test_gmm_with_initial_model(self):
+        from pyspark.mllib.clustering import GaussianMixture
+        data = self.sc.parallelize([
+            (-10, -5), (-9, -4), (10, 5), (9, 4)
+        ])
+
+        gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                     maxIterations=10, seed=63)
+        gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                     maxIterations=10, seed=63, initialModel=gmm1)
+        self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\