diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 69ce7f50709a13087f389ffc49914610df338330..21e55938fa7aa6d9f5664c233053c9944d2dc294 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -336,7 +336,8 @@ private[python] class PythonMLLibAPI extends Serializable {
       initializationMode: String,
       seed: java.lang.Long,
       initializationSteps: Int,
-      epsilon: Double): KMeansModel = {
+      epsilon: Double,
+      initialModel: java.util.ArrayList[Vector]): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
@@ -346,6 +347,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       .setEpsilon(epsilon)
 
     if (seed != null) kMeansAlg.setSeed(seed)
+    if (!initialModel.isEmpty()) kMeansAlg.setInitialModel(new KMeansModel(initialModel))
 
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 900ade248c3868537ce76046a8031b158ab01d37..6964a45db249396e5ecabd00af5a2c9cb8137046 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -90,6 +90,12 @@ class KMeansModel(Saveable, Loader):
     ...     rmtree(path)
     ... except OSError:
     ...     pass
+
+    >>> data = array([-383.1,-382.9, 28.7,31.2, 366.2,367.3]).reshape(3, 2)
+    >>> model = KMeans.train(sc.parallelize(data), 3, maxIterations=0,
+    ...     initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)]))
+    >>> model.clusterCenters
+    [array([-1000., -1000.]), array([ 5.,  5.]), array([ 1000.,  1000.])]
     """
 
     def __init__(self, centers):
@@ -144,10 +150,17 @@ class KMeans(object):
 
     @classmethod
     def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||",
-              seed=None, initializationSteps=5, epsilon=1e-4):
+              seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None):
         """Train a k-means clustering model."""
+        clusterInitialModel = []
+        if initialModel is not None:
+            if not isinstance(initialModel, KMeansModel):
+                raise Exception("initialModel is of "+str(type(initialModel))+". It needs "
+                                "to be of <type 'KMeansModel'>")
+            clusterInitialModel = [_convert_to_vector(c) for c in initialModel.clusterCenters]
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode, seed, initializationSteps, epsilon)
+                              runs, initializationMode, seed, initializationSteps, epsilon,
+                              clusterInitialModel)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])