Commit 643b4e22 authored by Zheng RuiFeng, committed by Xiangrui Meng

[SPARK-14510][MLLIB] Add args-checking for LDA and StreamingKMeans

## What changes were proposed in this pull request?
Add argument checks to the LDA and StreamingKMeans setters so that invalid parameters are rejected up front.

## How was this patch tested?
manual tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12062 from zhengruifeng/initmodel.
parent 1c751fcf
mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -130,7 +130,8 @@ class LDA private (
    */
   @Since("1.5.0")
   def setDocConcentration(docConcentration: Vector): this.type = {
-    require(docConcentration.size > 0, "docConcentration must have > 0 elements")
+    require(docConcentration.size == 1 || docConcentration.size == k,
+      s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}")
     this.docConcentration = docConcentration
     this
   }
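
With this change, setDocConcentration rejects a prior vector whose length is neither 1 nor k at the setter, instead of failing later during training. A minimal sketch of the new behavior (illustrative values, not from the patch):

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA().setK(3)
// Accepted: size k (one concentration per topic); a size-1 vector (symmetric prior) also passes.
lda.setDocConcentration(Vectors.dense(1.1, 1.1, 1.1))
// Rejected: size 2 is neither 1 nor k = 3, so require throws IllegalArgumentException.
lda.setDocConcentration(Vectors.dense(1.1, 1.1))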
@@ -260,15 +261,18 @@ class LDA private (
   def getCheckpointInterval: Int = checkpointInterval
 
   /**
-   * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+   * Parameter to set the checkpoint interval (>= 1) or disable checkpointing (-1). E.g. 10 means
+   * that the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery
    * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
    * important when LDA is run for many iterations. If the checkpoint directory is not set in
-   * [[org.apache.spark.SparkContext]], this setting is ignored.
+   * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10)
    *
    * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
    */
   @Since("1.3.0")
   def setCheckpointInterval(checkpointInterval: Int): this.type = {
+    require(checkpointInterval == -1 || checkpointInterval > 0,
+      s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}")
     this.checkpointInterval = checkpointInterval
     this
   }
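
The setter now accepts only -1 (checkpointing disabled) or a positive period. A short sketch of accepted and rejected values (illustrative, reusing the lda instance above):

lda.setCheckpointInterval(10)  // OK: checkpoint every 10 iterations
lda.setCheckpointInterval(-1)  // OK: checkpointing disabled
lda.setCheckpointInterval(0)   // throws IllegalArgumentException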
mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+    require(centers.size == weights.size,
+      "Number of initial centers must be equal to number of weights")
+    require(centers.size == k,
+      s"Number of initial centers must be ${k} but got ${centers.size}")
+    require(weights.forall(_ >= 0),
+      s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]")
     model = new StreamingKMeansModel(centers, weights)
     this
   }
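
The three require calls catch a centers/weights length mismatch, a center count different from k, and negative weights before the StreamingKMeansModel is built. A minimal sketch (illustrative values, assuming k = 2):

import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val skm = new StreamingKMeans().setK(2)
// Accepted: two centers with two nonnegative weights.
skm.setInitialCenters(
  Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)),
  Array(1.0, 1.0))
// Rejected: three centers but k = 2, so require throws IllegalArgumentException.
skm.setInitialCenters(
  Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(2.0, 2.0)),
  Array(1.0, 1.0, 1.0))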
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
    */
   @Since("1.2.0")
   def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+    require(dim > 0,
+      s"Number of dimensions must be positive but got ${dim}")
+    require(weight >= 0,
+      s"Weight for each center must be nonnegative but got ${weight}")
     val random = new XORShiftRandom(seed)
     val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
     val weights = Array.fill(k)(weight)
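
Similarly, setRandomCenters now fails fast on a nonpositive dimension or a negative weight instead of constructing degenerate centers. A brief sketch (illustrative, reusing the skm instance above):

skm.setRandomCenters(dim = 5, weight = 0.0)  // OK: k random 5-dimensional centers
skm.setRandomCenters(dim = 0, weight = 1.0)  // throws IllegalArgumentException: dim must be positive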