diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 850d775db0b72914312a104eea36894cde6b3e30..d51b80e16c13a3bd2131663b0f24199b7cdaaaf4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -278,7 +278,8 @@ class GBTParams(TreeEnsembleParams):
 @inherit_doc
 class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                              HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
-                             TreeClassifierParams, HasCheckpointInterval, HasSeed):
+                             TreeClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
+                             JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for classification.
@@ -313,6 +314,17 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> model.transform(test1).head().prediction
     1.0
 
+    >>> dtc_path = temp_path + "/dtc"
+    >>> dt.save(dtc_path)
+    >>> dt2 = DecisionTreeClassifier.load(dtc_path)
+    >>> dt2.getMaxDepth()
+    2
+    >>> model_path = temp_path + "/dtc_model"
+    >>> model.save(model_path)
+    >>> model2 = DecisionTreeClassificationModel.load(model_path)
+    >>> model.featureImportances == model2.featureImportances
+    True
+
     .. versionadded:: 1.4.0
     """
 
@@ -361,7 +373,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
 
 @inherit_doc
-class DecisionTreeClassificationModel(DecisionTreeModel):
+class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by DecisionTreeClassifier.
 
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 59d4fe3cf4eddd7195db51e020d062b379608dbd..37648549dee207d86e96e0aaa6f1f6ae9a4fb2b4 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -389,7 +389,7 @@ class GBTParams(TreeEnsembleParams):
 @inherit_doc
 class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                             DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
-                            HasSeed):
+                            HasSeed, JavaMLWritable, JavaMLReadable):
     """
     `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
     learning algorithm for regression.
@@ -413,6 +413,18 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
+    >>> dtr_path = temp_path + "/dtr"
+    >>> dt.save(dtr_path)
+    >>> dt2 = DecisionTreeRegressor.load(dtr_path)
+    >>> dt2.getMaxDepth()
+    2
+    >>> model_path = temp_path + "/dtr_model"
+    >>> model.save(model_path)
+    >>> model2 = DecisionTreeRegressionModel.load(model_path)
+    >>> model.numNodes == model2.numNodes
+    True
+    >>> model.depth == model2.depth
+    True
 
     .. versionadded:: 1.4.0
     """
@@ -498,7 +510,7 @@ class TreeEnsembleModels(JavaModel):
 
 
 @inherit_doc
-class DecisionTreeRegressionModel(DecisionTreeModel):
+class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by DecisionTreeRegressor.
 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2fa5da7738c1bf1411b2be42aaec9f192a769d2c..224232ed7f62f1ddb6a75a267aa1fba5e30ca48b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -42,13 +42,13 @@ import tempfile
 import numpy as np
 
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.regression import LinearRegression
+from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
 from pyspark.ml.wrapper import JavaWrapper
@@ -655,6 +655,42 @@ class PersistenceTest(PySparkTestCase):
             except OSError:
                 pass
 
+    def test_decisiontree_classifier(self):
+        dt = DecisionTreeClassifier(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtc_path = path + "/dtc"
+        dt.save(dtc_path)
+        dt2 = DecisionTreeClassifier.load(dtc_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeClassifier instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeClassifier instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
+    def test_decisiontree_regressor(self):
+        dt = DecisionTreeRegressor(maxDepth=1)
+        path = tempfile.mkdtemp()
+        dtr_path = path + "/dtr"
+        dt.save(dtr_path)
+        dt2 = DecisionTreeClassifier.load(dtr_path)
+        self.assertEqual(dt2.uid, dt2.maxDepth.parent,
+                         "Loaded DecisionTreeRegressor instance uid (%s) "
+                         "did not match Param's uid (%s)"
+                         % (dt2.uid, dt2.maxDepth.parent))
+        self.assertEqual(dt._defaultParamMap[dt.maxDepth], dt2._defaultParamMap[dt2.maxDepth],
+                         "Loaded DecisionTreeRegressor instance default params did not match " +
+                         "original defaults")
+        try:
+            rmtree(path)
+        except OSError:
+            pass
+
 
 class HasThrowableProperty(Params):