diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a87f8a6cde318bb11205217085b3ee26b7175605..c13b9a66c4e24cc47e26a8945cbbd5be58fff3be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -75,8 +75,8 @@ class DecisionTreeModel @Since("1.0.0") (
    * @return JavaRDD of predictions for each of the given data points
    */
   @Since("1.2.0")
-  def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
-    predict(features.rdd)
+  def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+    predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
   }
 
   /**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 8dd29061daaad1d857459cf31f72a78134eec0c8..60585d27277d1c211a0ab260e935a3919824e53a 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -28,6 +28,8 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.tree.configuration.Algo;
 import org.apache.spark.mllib.tree.configuration.Strategy;
@@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable {
 
     DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
 
+    // java compatibility test
+    JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() {
+      @Override
+      public Vector call(LabeledPoint v1) {
+        return v1.features();
+      }
+    }));
+
     int numCorrect = validatePrediction(arr, model);
     Assert.assertTrue(numCorrect == rdd.count());
   }