Skip to content
Snippets Groups Projects
Commit 9893dc97 authored by Zheng RuiFeng's avatar Zheng RuiFeng Committed by Sean Owen
Browse files

[SPARK-15610][ML] update error message for k in pca

## What changes were proposed in this pull request?
Fix the wrong bound of `k` in `PCA`
`require(k <= sources.first().size, ...`  ->  `require(k < sources.first().size`

BTW, remove unused import in `ml.ElementwiseProduct`

## How was this patch tested?

manual tests

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #13356 from zhengruifeng/fix_pca.
parent 88c9c467
No related branches found
No related tags found
No related merge requests found
......@@ -23,7 +23,6 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.types.DataType
......
......@@ -40,8 +40,9 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
*/
@Since("1.4.0")
def fit(sources: RDD[Vector]): PCAModel = {
require(k <= sources.first().size,
s"source vector size is ${sources.first().size} must be greater than k=$k")
val numFeatures = sources.first().size
require(k <= numFeatures,
s"source vector size $numFeatures must be no less than k=$k")
val mat = new RowMatrix(sources)
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
......@@ -58,7 +59,6 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
case m =>
throw new IllegalArgumentException("Unsupported matrix format. Expected " +
s"SparseMatrix or DenseMatrix. Instead got: ${m.getClass}")
}
val denseExplainedVariance = explainedVariance match {
case dv: DenseVector =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment