diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index fc0693f67cc2e682080eb9c8ecd65c0332542e7e..bc19bd6df894fa42189baf6401bc01bec250984b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String)
     val trees =
       RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
         .map(_.asInstanceOf[DecisionTreeClassificationModel])
-    new RandomForestClassificationModel(trees)
+    new RandomForestClassificationModel(trees, numClasses)
   }
 
   override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -125,8 +125,9 @@ object RandomForestClassifier {
 @Experimental
 final class RandomForestClassificationModel private[ml] (
     override val uid: String,
-    private val _trees: Array[DecisionTreeClassificationModel])
-  extends PredictionModel[Vector, RandomForestClassificationModel]
+    private val _trees: Array[DecisionTreeClassificationModel],
+    override val numClasses: Int)
+  extends ClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
    * Construct a random forest classification model, with all trees weighted equally.
    * @param trees  Component trees
    */
-  def this(trees: Array[DecisionTreeClassificationModel]) =
-    this(Identifiable.randomUID("rfc"), trees)
+  def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
+    this(Identifiable.randomUID("rfc"), trees, numClasses)
 
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
@@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] (
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
-  override protected def predict(features: Vector): Double = {
+  override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
     // Ignore the weights since all are 1.0 for now.
-    val votes = mutable.Map.empty[Int, Double]
+    val votes = new Array[Double](numClasses)
     _trees.view.foreach { tree =>
       val prediction = tree.rootNode.predict(features).toInt
-      votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
     }
-    votes.maxBy(_._2)._1
+    Vectors.dense(votes)
   }
 
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
-    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
+    copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }
 
   override def toString: String = {
@@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
   def fromOld(
       oldModel: OldRandomForestModel,
       parent: RandomForestClassifier,
-      categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int): RandomForestClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
       DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-    new RandomForestClassificationModel(uid, newTrees)
+    new RandomForestClassificationModel(uid, newTrees, numClasses)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 1b6b69c7dc71ef038c65fcd6a88358eaa9962d6e..ab711c8e4b2151b3b61bb0192061a1acdc073019 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Test suite for [[RandomForestClassifier]].
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
     val newModel = rf.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
+      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
+      numClasses)
     TreeTests.checkEqual(oldModelAsNew, newModel)
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
+    assert(newModel.numClasses == numClasses)
+    val results = newModel.transform(newData)
+    results.select("rawPrediction", "prediction").collect().foreach {
+      case Row(raw: Vector, prediction: Double) => {
+        assert(raw.size == numClasses)
+        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
+        assert(predFromRaw == prediction)
+      }
+    }
   }
 }
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89117e492846bd6f66b83dfbfb2884eeca534266..5a82bc286d1e879175846b835c9479ddb6d4aad6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
     >>> si_model = stringIndexer.fit(df)
     >>> td = si_model.transform(df)
-    >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = rf.fit(td)
-    >>> allclose(model.treeWeights, [1.0, 1.0])
+    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction