Skip to content
Snippets Groups Projects
Commit 07165ca0 authored by Kousuke Saruta's avatar Kousuke Saruta
Browse files

[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.

ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKey` is collection.Map, not mutable.Map but the result is casted to mutable.Map using `asInstanceOf` so we get `ClassCastException`.
Also, the return type of Map#filterKeys is not Serializable. It's the issue of Scala (https://issues.scala-lang.org/browse/SI-6654).

Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>

Closes #10381 from sarutak/SPARK-12424.
parent e01c6c86
No related branches found
No related tags found
No related merge requests found
......@@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Filters this param map for the given parent.
*/
def filter(parent: Params): ParamMap = {
val filtered = map.filterKeys(_.parent == parent)
new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
// Don't use filterKeys because mutable.Map#filterKeys
// returns the instance of collections.Map, not mutable.Map.
// Otherwise, we get ClassCastException.
// Not using filterKeys also avoid SI-6654
val filtered = map.filter { case (k, _) => k.parent == parent.uid }
new ParamMap(filtered)
}
/**
......
......@@ -17,7 +17,10 @@
package org.apache.spark.ml.param
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.MyParams
import org.apache.spark.mllib.linalg.{Vector, Vectors}
class ParamsSuite extends SparkFunSuite {
......@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}
test("Filtering ParamMap") {
val params1 = new MyParams("my_params1")
val params2 = new MyParams("my_params2")
val paramMap = ParamMap(
params1.intParam -> 1,
params2.intParam -> 1,
params1.doubleParam -> 0.2,
params2.doubleParam -> 0.2)
val filteredParamMap = paramMap.filter(params1)
assert(filteredParamMap.size === 2)
filteredParamMap.toSeq.foreach {
case ParamPair(p, _) =>
assert(p.parent === params1.uid)
}
// At the previous implementation of ParamMap#filter,
// mutable.Map#filterKeys was used internally but
// the return type of the method is not serializable (see SI-6654).
// Now mutable.Map#filter is used instead of filterKeys and the return type is serializable.
// So let's ensure serializability.
val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
objOut.writeObject(filteredParamMap)
}
}
object ParamsSuite extends SparkFunSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment