Commit 83302c3b authored by Xusen Yin, committed by Xiangrui Meng

[SPARK-13036][SPARK-13318][SPARK-13319] Add save/load for feature.py

Add save/load for feature.py. In addition, add save/load for `ElementwiseProduct` on the Scala side, and fix a bug where `setDefault` was missing in `VectorSlicer` and `StopWordsRemover`.

This PR omits `RFormula` and `RFormulaModel` because their Scala implementation is pending in https://github.com/apache/spark/pull/9884. I'll add them in this PR if https://github.com/apache/spark/pull/9884 is merged first; otherwise I'll file a follow-up JIRA for `RFormula`.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11203 from yinxusen/SPARK-13036.
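The persistence these changes enable follows Spark ML's `DefaultParamsWritable`/`DefaultParamsReadable` pattern: a transformer's explicitly set params are written as JSON metadata under a path, and `load` reconstructs an equivalent instance from that metadata. A minimal pure-Python sketch of that round trip (illustrative only; the class, fields, and file layout are assumptions, not Spark's actual on-disk format, and no Spark is required):

```python
import json
import os
import tempfile

class ElementwiseProductSketch:
    """Toy stand-in for a Params-based transformer (illustrative only)."""

    def __init__(self, input_col=None, output_col=None, scaling_vec=None):
        self.input_col = input_col
        self.output_col = output_col
        self.scaling_vec = scaling_vec

    def save(self, path):
        # DefaultParamsWritable-style: dump the set params as JSON metadata.
        os.makedirs(path, exist_ok=True)
        metadata = {
            "class": type(self).__name__,
            "paramMap": {
                "inputCol": self.input_col,
                "outputCol": self.output_col,
                "scalingVec": self.scaling_vec,
            },
        }
        with open(os.path.join(path, "metadata.json"), "w") as f:
            json.dump(metadata, f)

    @classmethod
    def load(cls, path):
        # DefaultParamsReadable-style: rebuild the instance from metadata.
        with open(os.path.join(path, "metadata.json")) as f:
            metadata = json.load(f)
        p = metadata["paramMap"]
        return cls(p["inputCol"], p["outputCol"], p["scalingVec"])

path = os.path.join(tempfile.mkdtemp(), "elemProd")
ep = ElementwiseProductSketch("myInputCol", "myOutputCol", [0.1, 0.2])
ep.save(path)
loaded = ElementwiseProductSketch.load(path)
```

After `load`, the reconstructed instance carries the same param values as the one that was saved, which is exactly the property the test suite below verifies for the real `ElementwiseProduct`.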
parent c8f25459
@@ -17,10 +17,10 @@
 package org.apache.spark.ml.feature

-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.UnaryTransformer
 import org.apache.spark.ml.param.Param
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.types.DataType
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DataType
  */
 @Experimental
 class ElementwiseProduct(override val uid: String)
-  extends UnaryTransformer[Vector, Vector, ElementwiseProduct] {
+  extends UnaryTransformer[Vector, Vector, ElementwiseProduct] with DefaultParamsWritable {

   def this() = this(Identifiable.randomUID("elemProd"))
@@ -57,3 +57,10 @@ class ElementwiseProduct(override val uid: String)

   override protected def outputDataType: DataType = new VectorUDT()
 }
+
+@Since("2.0.0")
+object ElementwiseProduct extends DefaultParamsReadable[ElementwiseProduct] {
+
+  @Since("2.0.0")
+  override def load(path: String): ElementwiseProduct = super.load(path)
+}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

class ElementwiseProductSuite
  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

  test("read/write") {
    val ep = new ElementwiseProduct()
      .setInputCol("myInputCol")
      .setOutputCol("myOutputCol")
      .setScalingVec(Vectors.dense(0.1, 0.2))
    testDefaultReadWrite(ep)
  }
}