Skip to content
Snippets Groups Projects
Commit d15b4f90 authored by WeichenXu's avatar WeichenXu Committed by Sean Owen
Browse files

[SPARK-17507][ML][MLLIB] check weight vector size in ANN

## What changes were proposed in this pull request?

as the TODO described,
check weight vector size and if wrong throw exception.

## How was this patch tested?

existing tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #15060 from WeichenXu123/check_input_weight_size_of_ann.
parent 6a6adb16
No related branches found
No related tags found
No related merge requests found
......@@ -545,7 +545,9 @@ private[ann] object FeedForwardModel {
* @return model
*/
def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
// TODO: check that weights size is equal to sum of layers sizes
val expectedWeightSize = topology.layers.map(_.weightSize).sum
require(weights.size == expectedWeightSize,
s"Expected weight vector of size ${expectedWeightSize} but got size ${weights.size}.")
new FeedForwardModel(weights, topology)
}
......@@ -559,11 +561,7 @@ private[ann] object FeedForwardModel {
def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
val layers = topology.layers
val layerModels = new Array[LayerModel](layers.length)
var totalSize = 0
for (i <- 0 until topology.layers.length) {
totalSize += topology.layers(i).weightSize
}
val weights = BDV.zeros[Double](totalSize)
val weights = BDV.zeros[Double](topology.layers.map(_.weightSize).sum)
var offset = 0
val random = new XORShiftRandom(seed)
for (i <- 0 until layers.length) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment