diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 4307ad02a0ebd966f59b9500f81f500e05ff9a7e..a78e3b49fbcfc143431bb5a18c1a421de0541fcd 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -21,7 +21,7 @@ if sys.version > '3':
     basestring = str
 
 from pyspark import since, keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
 from pyspark.ml.wrapper import JavaParams