From ef2f55b97f58fa06acb30e9e0172fb66fba383bc Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Mon, 9 Feb 2015 22:09:07 -0800
Subject: [PATCH] [SPARK-5597][MLLIB] save/load for decision trees and
 emsembles

This is based on #4444 from jkbradley with the following changes:

1. Node schema updated to
   ~~~
treeId: int
nodeId: Int
predict/
       |- predict: Double
       |- prob: Double
impurity: Double
isLeaf: Boolean
split/
     |- feature: Int
     |- threshold: Double
     |- featureType: Int
     |- categories: Array[Double]
leftNodeId: Integer
rightNodeId: Integer
infoGain: Double
~~~

2. Some refactor of the implementation.

Closes #4444.

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4493 from mengxr/SPARK-5597 and squashes the following commits:

75e3bb6 [Xiangrui Meng] fix style
2b0033d [Xiangrui Meng] update tree export schema and refactor the implementation
45873a2 [Joseph K. Bradley] org imports
1d4c264 [Joseph K. Bradley] Added save/load for tree ensembles
dcdbf85 [Joseph K. Bradley] added save/load for decision tree but need to generalize it to ensembles
---
 .../mllib/tree/model/DecisionTreeModel.scala  | 197 +++++++++++++++++-
 .../tree/model/InformationGainStats.scala     |   4 +-
 .../apache/spark/mllib/tree/model/Node.scala  |   5 +
 .../spark/mllib/tree/model/Predict.scala      |   7 +
 .../mllib/tree/model/treeEnsembleModels.scala | 157 +++++++++++++-
 .../spark/mllib/tree/DecisionTreeSuite.scala  | 120 ++++++++++-
 .../tree/GradientBoostedTreesSuite.scala      |  81 ++++---
 .../spark/mllib/tree/RandomForestSuite.scala  |  28 ++-
 8 files changed, 561 insertions(+), 38 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a25e625a40..89ecf3773d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,17 @@
 
 package org.apache.spark.mllib.tree.model
 
+import scala.collection.mutable
+
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 
 /**
  * :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
  * @param algo algorithm type -- classification or regression
  */
 @Experimental
-class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
 
   /**
    * Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
     header + topNode.subtreeToString(2)
   }
 
+  override def save(sc: SparkContext, path: String): Unit = {
+    DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+  private[tree] object SaveLoadV1_0 {
+
+    def thisFormatVersion = "1.0"
+
+    // Hard-code class name string in case it changes in the future
+    def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
+    case class PredictData(predict: Double, prob: Double) {
+      def toPredict: Predict = new Predict(predict, prob)
+    }
+
+    object PredictData {
+      def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+
+      def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
+    }
+
+    case class SplitData(
+        feature: Int,
+        threshold: Double,
+        featureType: Int,
+        categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+      def toSplit: Split = {
+        new Split(feature, threshold, FeatureType(featureType), categories.toList)
+      }
+    }
+
+    object SplitData {
+      def apply(s: Split): SplitData = {
+        SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+      }
+
+      def apply(r: Row): SplitData = {
+        SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
+      }
+    }
+
+    /** Model data for model import/export */
+    case class NodeData(
+        treeId: Int,
+        nodeId: Int,
+        predict: PredictData,
+        impurity: Double,
+        isLeaf: Boolean,
+        split: Option[SplitData],
+        leftNodeId: Option[Int],
+        rightNodeId: Option[Int],
+        infoGain: Option[Double])
+
+    object NodeData {
+      def apply(treeId: Int, n: Node): NodeData = {
+        NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
+          n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
+          n.stats.map(_.gain))
+      }
+
+      def apply(r: Row): NodeData = {
+        val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
+        val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
+        val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
+        val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
+        NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
+          r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
+      }
+    }
+
+    def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      // Create JSON metadata.
+      val metadataRDD = sc.parallelize(
+        Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1)
+        .toDataFrame("class", "version", "algo", "numNodes")
+      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val nodes = model.topNode.subtreeIterator.toSeq
+      val dataRDD: DataFrame = sc.parallelize(nodes)
+        .map(NodeData.apply(0, _))
+        .toDataFrame
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+      val datapath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      // Load Parquet data.
+      val dataRDD = sqlContext.parquetFile(datapath)
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      Loader.checkSchema[NodeData](dataRDD.schema)
+      val nodes = dataRDD.map(NodeData.apply)
+      // Build node data into a tree.
+      val trees = constructTrees(nodes)
+      assert(trees.size == 1,
+        "Decision tree should contain exactly one tree but got ${trees.size} trees.")
+      val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
+      assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
+        s" Expected $numNodes nodes but found ${model.numNodes}")
+      model
+    }
+
+    def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
+      val trees = nodes
+        .groupBy(_.treeId)
+        .mapValues(_.toArray)
+        .collect()
+        .map { case (treeId, data) =>
+          (treeId, constructTree(data))
+        }.sortBy(_._1)
+      val numTrees = trees.size
+      val treeIndices = trees.map(_._1).toSeq
+      assert(treeIndices == (0 until numTrees),
+        s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
+      trees.map(_._2)
+    }
+
+    /**
+     * Given a list of nodes from a tree, construct the tree.
+     * @param data array of all node data in a tree.
+     */
+    def constructTree(data: Array[NodeData]): Node = {
+      val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
+      assert(dataMap.contains(1),
+        s"DecisionTree missing root node (id = 1).")
+      constructNode(1, dataMap, mutable.Map.empty)
+    }
+
+    /**
+     * Builds a node from the node data map and adds new nodes to the input nodes map.
+     */
+    private def constructNode(
+      id: Int,
+      dataMap: Map[Int, NodeData],
+      nodes: mutable.Map[Int, Node]): Node = {
+      if (nodes.contains(id)) {
+        return nodes(id)
+      }
+      val data = dataMap(id)
+      val node =
+        if (data.isLeaf) {
+          Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
+        } else {
+          val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
+          val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
+          val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
+            rightNode.impurity, leftNode.predict, rightNode.predict)
+          new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
+            data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
+        }
+      nodes += node.id -> node
+      node
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+    val (algo: String, numNodes: Int) = try {
+      val algo_numNodes = metadata.select("algo", "numNodes").collect()
+      assert(algo_numNodes.length == 1)
+      algo_numNodes(0) match {
+        case Row(a: String, n: Int) => (a, n)
+      }
+    } catch {
+      // Catch both Error and Exception since the checks above can throw either.
+      case e: Throwable =>
+        throw new Exception(
+          s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
+          + s"  Error message: ${e.getMessage}")
+    }
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path, algo, numNodes)
+      case _ => throw new Exception(
+        s"DecisionTreeModel.load did not recognize model with (className, format version):" +
+        s"($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 9a50ecb550..80990aa9a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -49,7 +49,9 @@ class InformationGainStats(
         gain == other.gain &&
         impurity == other.impurity &&
         leftImpurity == other.leftImpurity &&
-        rightImpurity == other.rightImpurity
+        rightImpurity == other.rightImpurity &&
+        leftPredict == other.leftPredict &&
+        rightPredict == other.rightPredict
       }
       case _ => false
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 2179da8dbe..d961081d18 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -166,6 +166,11 @@ class Node (
     }
   }
 
+  /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
+  private[tree] def subtreeIterator: Iterator[Node] = {
+    Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
+      rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
+  }
 }
 
 private[tree] object Node {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 004838ee5b..ad4c0dbbfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -32,4 +32,11 @@ class Predict(
   override def toString = {
     "predict = %f, prob = %f".format(predict, prob)
   }
+
+  override def equals(other: Any): Boolean = {
+    other match {
+      case p: Predict => predict == p.predict && prob == p.prob
+      case _ => false
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 22997110de..23bd46baab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -21,12 +21,17 @@ import scala.collection.mutable
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.Algo
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.util.{Saveable, Loader}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
 
 /**
  * :: Experimental ::
@@ -38,9 +43,42 @@ import org.apache.spark.rdd.RDD
 @Experimental
 class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
   extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
-    combiningStrategy = if (algo == Classification) Vote else Average) {
+    combiningStrategy = if (algo == Classification) Vote else Average)
+  with Saveable {
 
   require(trees.forall(_.algo == algo))
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+      RandomForestModel.SaveLoadV1_0.thisClassName)
+  }
+
+  override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object RandomForestModel extends Loader[RandomForestModel] {
+
+  override def load(sc: SparkContext, path: String): RandomForestModel = {
+    val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+        assert(metadata.treeWeights.forall(_ == 1.0))
+        val trees =
+          TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+        new RandomForestModel(Algo.fromString(metadata.algo), trees)
+      case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
+        s" with (className, format version): ($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  private object SaveLoadV1_0 {
+    // Hard-code class name string in case it changes in the future
+    def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+  }
+
 }
 
 /**
@@ -56,9 +94,42 @@ class GradientBoostedTreesModel(
     override val algo: Algo,
     override val trees: Array[DecisionTreeModel],
     override val treeWeights: Array[Double])
-  extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+  extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
+  with Saveable {
 
   require(trees.size == treeWeights.size)
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+      GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
+  }
+
+  override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
+
+  override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
+    val (loadedClassName, version, metadataRDD) = Loader.loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, version) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(metadataRDD, path)
+        assert(metadata.combiningStrategy == Sum.toString)
+        val trees =
+          TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+        new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights)
+      case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" +
+        s" with (className, format version): ($loadedClassName, $version).  Supported:\n" +
+        s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  private object SaveLoadV1_0 {
+    // Hard-code class name string in case it changes in the future
+    def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+  }
+
 }
 
 /**
@@ -176,3 +247,85 @@ private[tree] sealed class TreeEnsembleModel(
    */
   def totalNumNodes: Int = trees.map(_.numNodes).sum
 }
+
+private[tree] object TreeEnsembleModel {
+
+  object SaveLoadV1_0 {
+
+    import DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+
+    def thisFormatVersion = "1.0"
+
+    case class Metadata(
+        algo: String,
+        treeAlgo: String,
+        combiningStrategy: String,
+        treeWeights: Array[Double])
+
+    /**
+     * Model data for model import/export.
+     * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields
+     * of nested fields; once that is possible, we can use something like:
+     *  case class EnsembleNodeData(treeId: Int, node: NodeData),
+     *  where NodeData is from DecisionTreeModel.
+     */
+    case class EnsembleNodeData(treeId: Int, node: NodeData)
+
+    def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+
+      // Create JSON metadata.
+      val metadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+        model.combiningStrategy.toString, model.treeWeights)
+      val metadataRDD = sc.parallelize(Seq((className, thisFormatVersion, metadata)), 1)
+        .toDataFrame("class", "version", "metadata")
+      metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
+        tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
+      }.toDataFrame
+      dataRDD.saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    /**
+     * Read metadata from the loaded metadata DataFrame.
+     * @param path  Path for loading data, used for debug messages.
+     */
+    def readMetadata(metadata: DataFrame, path: String): Metadata = {
+      try {
+        // We rely on the try-catch for schema checking rather than creating a schema just for this.
+        val metadataArray = metadata.select("metadata.algo", "metadata.treeAlgo",
+          "metadata.combiningStrategy", "metadata.treeWeights").collect()
+        assert(metadataArray.size == 1)
+        Metadata(metadataArray(0).getString(0), metadataArray(0).getString(1),
+          metadataArray(0).getString(2), metadataArray(0).getAs[Seq[Double]](3).toArray)
+      } catch {
+        // Catch both Error and Exception since the checks above can throw either.
+        case e: Throwable =>
+          throw new Exception(
+            s"Unable to load TreeEnsembleModel metadata from: ${Loader.metadataPath(path)}."
+              + s"  Error message: ${e.getMessage}")
+      }
+    }
+
+    /**
+     * Load trees for an ensemble, and return them in order.
+     * @param path path to load the model from
+     * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's
+     *                 algorithm).
+     */
+    def loadTrees(
+        sc: SparkContext,
+        path: String,
+        treeAlgo: String): Array[DecisionTreeModel] = {
+      val datapath = Loader.dataPath(path)
+      val sqlContext = new SQLContext(sc)
+      val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
+      val trees = constructTrees(nodes)
+      trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
+    }
+  }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9347eaf922..7b1aed5ffe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
 
 class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -857,9 +859,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
     assert(topNode.leftNode.get.impurity === 0.0)
     assert(topNode.rightNode.get.impurity === 0.0)
   }
+
+  test("Node.subtreeIterator") {
+    val model = DecisionTreeSuite.createModel(Classification)
+    val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted
+    assert(nodeIds === DecisionTreeSuite.createdModelNodeIds)
+  }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(Classification, Regression).foreach { algo =>
+      val model = DecisionTreeSuite.createModel(algo)
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = DecisionTreeModel.load(sc, path)
+        DecisionTreeSuite.checkEqual(model, sameModel)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
 }
 
-object DecisionTreeSuite {
+object DecisionTreeSuite extends FunSuite {
 
   def validateClassifier(
       model: DecisionTreeModel,
@@ -979,4 +1004,95 @@ object DecisionTreeSuite {
     arr
   }
 
+  /** Create a leaf node with the given node ID */
+  private def createLeafNode(id: Int): Node = {
+    Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true)
+  }
+
+  /**
+   * Create an internal node with the given node ID and feature type.
+   * Note: This does NOT set the child nodes.
+   */
+  private def createInternalNode(id: Int, featureType: FeatureType): Node = {
+    val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false)
+    featureType match {
+      case Continuous =>
+        node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous,
+          categories = List.empty[Double]))
+      case Categorical =>
+        node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
+          categories = List(0.0, 1.0)))
+    }
+    // TODO: The information gain stats should be consistent with the same info stored in children.
+    node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
+      leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
+    node
+  }
+
+  /**
+   * Create a tree model.  This is deterministic and contains a variety of node and feature types.
+   */
+  private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+    val topNode = createInternalNode(id = 1, Continuous)
+    val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
+    val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
+    topNode.leftNode = Some(node2)
+    topNode.rightNode = Some(node3)
+    node3.leftNode = Some(node6)
+    node3.rightNode = Some(node7)
+    new DecisionTreeModel(topNode, algo)
+  }
+
+  /** Sorted Node IDs matching the model returned by [[createModel()]] */
+  private val createdModelNodeIds = Array(1, 2, 3, 6, 7)
+
+  /**
+   * Check if the two trees are exactly the same.
+   * Note: I hesitate to override Node.equals since it could cause problems if users
+   *       make mistakes such as creating loops of Nodes.
+   * If the trees are not equal, this prints the two trees and throws an exception.
+   */
+  private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+    try {
+      assert(a.algo === b.algo)
+      checkEqual(a.topNode, b.topNode)
+    } catch {
+      case ex: Exception =>
+        throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+          "TREE A:\n" + a.toDebugString + "\n" +
+          "TREE B:\n" + b.toDebugString + "\n", ex)
+    }
+  }
+
+  /**
+   * Return true iff the two nodes and their descendents are exactly the same.
+   * Note: I hesitate to override Node.equals since it could cause problems if users
+   *       make mistakes such as creating loops of Nodes.
+   */
+  private def checkEqual(a: Node, b: Node): Unit = {
+    assert(a.id === b.id)
+    assert(a.predict === b.predict)
+    assert(a.impurity === b.impurity)
+    assert(a.isLeaf === b.isLeaf)
+    assert(a.split === b.split)
+    (a.stats, b.stats) match {
+      // TODO: Check other fields besides the infomation gain.
+      case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
+      case (None, None) =>
+      case _ => throw new AssertionError(
+          s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+    }
+    (a.leftNode, b.leftNode) match {
+      case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
+      case (None, None) =>
+      case _ => throw new AssertionError("Only one instance has leftNode defined. " +
+        s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+    }
+    (a.rightNode, b.rightNode) match {
+      case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
+      case (None, None) =>
+      case _ => throw new AssertionError("Only one instance has rightNode defined. " +
+        s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index e8341a5d0d..bde47606eb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
 import org.apache.spark.mllib.tree.impurity.Variance
 import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
-
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
 
 /**
  * Test suite for [[GradientBoostedTrees]].
@@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
   test("Regression with continuous features: SquaredError") {
     GradientBoostedTreesSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
-          val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
-
-          val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-            categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
-          val boostingStrategy =
-            new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
-          val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
-          assert(gbt.trees.size === numIterations)
-          try {
-            EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
-          } catch {
-            case e: java.lang.AssertionError =>
-              println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
-                s" subsamplingRate=$subsamplingRate")
-              throw e
-          }
-
-          val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-          val dt = DecisionTree.train(remappedInput, treeStrategy)
-
-          // Make sure trees are the same.
-          assert(gbt.trees.head.toString == dt.toString)
+        val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+        val boostingStrategy =
+          new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+        assert(gbt.trees.size === numIterations)
+        try {
+          EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+        } catch {
+          case e: java.lang.AssertionError =>
+            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+              s" subsamplingRate=$subsamplingRate")
+            throw e
         }
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        // Make sure trees are the same.
+        assert(gbt.trees.head.toString == dt.toString)
     }
   }
 
@@ -133,14 +133,37 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
       BoostingStrategy.defaultParams(algo)
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray
+    val treeWeights = Array(0.1, 0.3, 1.1)
+
+    Array(Classification, Regression).foreach { algo =>
+      val model = new GradientBoostedTreesModel(algo, trees, treeWeights)
+
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = GradientBoostedTreesModel.load(sc, path)
+        assert(model.algo == sameModel.algo)
+        model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+          DecisionTreeSuite.checkEqual(treeA, treeB)
+        }
+        assert(model.treeWeights === sameModel.treeWeights)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
 }
 
-object GradientBoostedTreesSuite {
+private object GradientBoostedTreesSuite {
 
   // Combinations for estimators, learning rates and subsamplingRate
   val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
 
-  val randomSeeds = Array(681283, 4398)
-
   val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 55e963977b..ee3bc98486 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
 import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.Node
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
 
 /**
  * Test suite for [[RandomForest]].
@@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
     assert(rf1.toDebugString != rf2.toDebugString)
   }
 
-}
-
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(Classification, Regression).foreach { algo =>
+      val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray
+      val model = new RandomForestModel(algo, trees)
+
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = RandomForestModel.load(sc, path)
+        assert(model.algo == sameModel.algo)
+        model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+          DecisionTreeSuite.checkEqual(treeA, treeB)
+        }
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
 
+}
-- 
GitLab