Commit 0ebf7c1b authored by zero323, committed by Sean Owen

[SPARK-17027][ML] Avoid integer overflow in PolynomialExpansion.getPolySize

## What changes were proposed in this pull request?

Replaces the custom `choose` function with `o.a.commons.math3.CombinatoricsUtils.binomialCoefficient`, which computes the binomial coefficient in `Long` arithmetic and so avoids the `Int` overflow in the intermediate products.
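
As an illustration (not part of the original commit message), a minimal sketch of the failure mode: the intermediate `Range(n, n - k, -1).product` overflows `Int` even when the final coefficient easily fits, whereas `CombinatoricsUtils.binomialCoefficient` works in `Long` and only throws if the result itself cannot be represented.

```scala
import org.apache.commons.math3.util.CombinatoricsUtils

// Old helper: numerator and denominator are built as Int products, so
// choose(16, 11) overflows even though C(16, 11) = 4368 fits in an Int.
def choose(n: Int, k: Int): Int =
  Range(n, n - k, -1).product / Range(k, 1, -1).product

choose(16, 11)                                  // wrong (negative) result due to Int overflow
CombinatoricsUtils.binomialCoefficient(16, 11)  // 4368L
```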

## How was this patch tested?

Spark unit tests

Author: zero323 <zero323@users.noreply.github.com>

Closes #14614 from zero323/SPARK-17027.
parent cdaa562c
PolynomialExpansion.scala:

@@ -19,6 +19,8 @@ package org.apache.spark.ml.feature
 
 import scala.collection.mutable
 
+import org.apache.commons.math3.util.CombinatoricsUtils
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.linalg._
@@ -84,12 +86,12 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
 @Since("1.6.0")
 object PolynomialExpansion extends DefaultParamsReadable[PolynomialExpansion] {
 
-  private def choose(n: Int, k: Int): Int = {
-    Range(n, n - k, -1).product / Range(k, 1, -1).product
+  private def getPolySize(numFeatures: Int, degree: Int): Int = {
+    val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
+    require(n <= Integer.MAX_VALUE)
+    n.toInt
   }
 
-  private def getPolySize(numFeatures: Int, degree: Int): Int = choose(numFeatures + degree, degree)
-
   private def expandDense(
       values: Array[Double],
       lastIdx: Int,
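For illustration only, a standalone sketch of the same guard as the new `getPolySize` (the helper name and the sample values below are mine, not part of the commit): `binomialCoefficient` computes in `Long`, and the `require` makes sizes that fit in a `Long` but not in an `Int` fail fast instead of silently wrapping.

```scala
import org.apache.commons.math3.util.CombinatoricsUtils

// Mirrors the new getPolySize: compute C(numFeatures + degree, degree) in Long,
// then refuse to narrow it when it does not fit in an Int.
def polySize(numFeatures: Int, degree: Int): Int = {
  val n = CombinatoricsUtils.binomialCoefficient(numFeatures + degree, degree)
  require(n <= Integer.MAX_VALUE)
  n.toInt
}

polySize(5, 11)   // 4368, i.e. C(16, 11)
polySize(90, 10)  // C(100, 10) ~ 1.7e13 > Int.MaxValue: IllegalArgumentException
```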
PolynomialExpansionSuite.scala:

@@ -116,5 +116,29 @@ class PolynomialExpansionSuite
       .setDegree(3)
     testDefaultReadWrite(t)
   }
+
+  test("SPARK-17027. Integer overflow in PolynomialExpansion.getPolySize") {
+    val data: Array[(Vector, Int, Int)] = Array(
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0), 3002, 4367),
+      (Vectors.sparse(5, Seq((0, 1.0), (4, 5.0))), 3002, 4367),
+      (Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0), 8007, 12375)
+    )
+
+    val df = spark.createDataFrame(data)
+      .toDF("features", "expectedPoly10size", "expectedPoly11size")
+
+    val t = new PolynomialExpansion()
+      .setInputCol("features")
+      .setOutputCol("polyFeatures")
+
+    for (i <- Seq(10, 11)) {
+      val transformed = t.setDegree(i)
+        .transform(df)
+        .select(s"expectedPoly${i}size", "polyFeatures")
+        .rdd.map { case Row(expected: Int, v: Vector) => expected == v.size }
+
+      assert(transformed.collect.forall(identity))
+    }
+  }
 }
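
As a quick cross-check of the expected sizes used in the test above (my own arithmetic, not part of the commit): `getPolySize` returns C(numFeatures + degree, degree), and the transformer's output vector omits the constant term, hence the minus one.

```scala
import org.apache.commons.math3.util.CombinatoricsUtils

// Expected output sizes in the test data: C(numFeatures + degree, degree) - 1
CombinatoricsUtils.binomialCoefficient(5 + 10, 10) - 1  // 3002  (5 features, degree 10)
CombinatoricsUtils.binomialCoefficient(5 + 11, 11) - 1  // 4367  (5 features, degree 11)
CombinatoricsUtils.binomialCoefficient(6 + 10, 10) - 1  // 8007  (6 features, degree 10)
CombinatoricsUtils.binomialCoefficient(6 + 11, 11) - 1  // 12375 (6 features, degree 11)
```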