Skip to content
Snippets Groups Projects
Commit 44948a2e authored by Holden Karau's avatar Holden Karau Committed by Xiangrui Meng
Browse files

[SPARK-9723] [ML] params getordefault should throw more useful error

Params.getOrDefault should throw a more meaningful exception than what you get from a bad key lookup.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8567 from holdenk/SPARK-9723-params-getordefault-should-throw-more-useful-error.
parent 03f3e91f
No related branches found
No related tags found
No related merge requests found
...@@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable { ...@@ -461,7 +461,8 @@ trait Params extends Identifiable with Serializable {
*/ */
final def getOrDefault[T](param: Param[T]): T = { final def getOrDefault[T](param: Param[T]): T = {
shouldOwn(param) shouldOwn(param)
get(param).orElse(getDefault(param)).get get(param).orElse(getDefault(param)).getOrElse(
throw new NoSuchElementException(s"Failed to find a default value for ${param.name}"))
} }
/** An alias for [[getOrDefault()]]. */ /** An alias for [[getOrDefault()]]. */
......
...@@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite { ...@@ -40,6 +40,10 @@ class ParamsSuite extends SparkFunSuite {
assert(inputCol.toString === s"${uid}__inputCol") assert(inputCol.toString === s"${uid}__inputCol")
intercept[java.util.NoSuchElementException] {
solver.getOrDefault(solver.handleInvalid)
}
intercept[IllegalArgumentException] { intercept[IllegalArgumentException] {
solver.setMaxIter(-1) solver.setMaxIter(-1)
} }
...@@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite { ...@@ -102,12 +106,13 @@ class ParamsSuite extends SparkFunSuite {
test("params") { test("params") {
val solver = new TestParams() val solver = new TestParams()
import solver.{maxIter, inputCol} import solver.{handleInvalid, maxIter, inputCol}
val params = solver.params val params = solver.params
assert(params.length === 2) assert(params.length === 3)
assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(0).eq(handleInvalid), "params must be ordered by name")
assert(params(1).eq(maxIter)) assert(params(1).eq(inputCol), "params must be ordered by name")
assert(params(2).eq(maxIter))
assert(!solver.isSet(maxIter)) assert(!solver.isSet(maxIter))
assert(solver.isDefined(maxIter)) assert(solver.isDefined(maxIter))
...@@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite { ...@@ -122,7 +127,7 @@ class ParamsSuite extends SparkFunSuite {
assert(solver.explainParam(maxIter) === assert(solver.explainParam(maxIter) ===
"maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)") "maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
assert(solver.explainParams() === assert(solver.explainParams() ===
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n")) Seq(handleInvalid, inputCol, maxIter).map(solver.explainParam).mkString("\n"))
assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("inputCol").eq(inputCol))
assert(solver.getParam("maxIter").eq(maxIter)) assert(solver.getParam("maxIter").eq(maxIter))
......
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
package org.apache.spark.ml.param package org.apache.spark.ml.param
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasMaxIter}
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util.Identifiable
/** A subclass of Params for testing. */ /** A subclass of Params for testing. */
class TestParams(override val uid: String) extends Params with HasMaxIter with HasInputCol { class TestParams(override val uid: String) extends Params with HasHandleInvalid with HasMaxIter
with HasInputCol {
def this() = this(Identifiable.randomUID("testParams")) def this() = this(Identifiable.randomUID("testParams"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment