-
- Downloads
[SPARK-3158][MLLIB]Avoid 1 extra aggregation for DecisionTree training
Currently, the implementation does one unnecessary aggregation step. The aggregation step for level L (to choose splits) gives enough information to set the predictions of any leaf nodes at level L+1. We can use that info and skip the aggregation step for the last level of the tree (which only has leaf nodes). ### Implementation Details Each node now has a `impurity` field and the `predict` is changed from type `Double` to type `Predict`(this can be used to compute predict probability in the future) When compute best splits for each node, we also compute impurity and predict for the child nodes, which is used to constructed newly allocated child nodes. So at level L, we have set impurity and predict for nodes at level L +1. If level L+1 is the last level, then we can avoid aggregation. What's more, calculation of parent impurity in Top nodes for each tree needs to be treated differently because we have to compute impurity and predict for them first. In `binsToBestSplit`, if current node is top node(level == 0), we calculate impurity and predict first. after finding best split, top node's predict and impurity is set to the calculated value. Non-top nodes's impurity and predict are already calculated and don't need to be recalculated again. I have considered to add a initialization step to set top nodes' impurity and predict and then we can treat all nodes in the same way, but this will need a lot of duplication of code(all the code to do seq operation(BinSeqOp) needs to be duplicated), so I choose the current way. CC mengxr manishamde jkbradley, please help me review this, thanks. Author: Qiping Li <liqiping1991@gmail.com> Closes #2708 from chouqin/avoid-agg and squashes the following commits: 8e269ea [Qiping Li] adjust code and comments eefeef1 [Qiping Li] adjust comments and check child nodes' impurity c41b1b6 [Qiping Li] fix pyspark unit test 7ad7a71 [Qiping Li] fix unit test 822c912 [Qiping Li] add comments and unit test e41d715 [Qiping Li] fix bug in test suite 6cc0333 [Qiping Li] SPARK-3158: Avoid 1 extra aggregation for DecisionTree training
Showing
- mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala 66 additions, 31 deletions...main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
- mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala 7 additions, 2 deletions.../apache/spark/mllib/tree/model/InformationGainStats.scala
- mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala 30 additions, 7 deletions...c/main/scala/org/apache/spark/mllib/tree/model/Node.scala
- mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala 94 additions, 8 deletions...scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
Loading
Please register or sign in to comment