Commit 186bf8fb authored by Bago Amirbekian, committed by Joseph K. Bradley

[SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline

## What changes were proposed in this pull request?

Including VectorSizeHint in RFormula pipelines allows them to be applied to streaming DataFrames.
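
For context, here is a minimal sketch (not part of this patch; the column names, the `spark` session, and the input path are illustrative) of the behaviour this enables: a fitted `RFormula` pipeline now carries a `VectorSizeHint` stage for each vector term, so the resulting model can also be applied to a streaming DataFrame, where the vector size cannot be inferred by scanning the data.

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.linalg.Vectors

import spark.implicits._  // assumes an existing SparkSession named `spark`

// Batch data used to fit the formula; "userFeatures" is a vector column.
val training = Seq(
  (1.0, Vectors.dense(0.0, 1.0)),
  (0.0, Vectors.dense(1.0, 0.0))
).toDF("clicked", "userFeatures")

// With this change, the fitted PipelineModel includes a VectorSizeHint stage
// that records the size of "userFeatures" at fit time.
val model = new RFormula()
  .setFormula("clicked ~ userFeatures")
  .fit(training)

// Because the size is declared inside the pipeline, transform() also works on
// a streaming DataFrame, where peeking at the first row is not possible.
val streaming = spark.readStream.schema(training.schema).parquet("/tmp/events")
val assembled = model.transform(streaming)
```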

## How was this patch tested?

Unit tests.

Author: Bago Amirbekian <bago@databricks.com>

Closes #20238 from MrBago/rFormulaVectorSize.
parent 6f7aaed8
@@ -130,3 +130,4 @@ read.ml <- function(path) {
     stop("Unsupported model: ", jobj)
   }
 }
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer}
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol}
 import org.apache.spark.ml.util._
@@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
     // First we index each string column referenced by the input terms.
     val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term =>
-      dataset.schema(term) match {
-        case column if column.dataType == StringType =>
+      dataset.schema(term).dataType match {
+        case _: StringType =>
           val indexCol = tmpColumn("stridx")
           encoderStages += new StringIndexer()
             .setInputCol(term)
@@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String)
             .setHandleInvalid($(handleInvalid))
           prefixesToRewrite(indexCol + "_") = term + "_"
           (term, indexCol)
+        case _: VectorUDT =>
+          val group = AttributeGroup.fromStructField(dataset.schema(term))
+          val size = if (group.size < 0) {
+            dataset.select(term).first().getAs[Vector](0).size
+          } else {
+            group.size
+          }
+          encoderStages += new VectorSizeHint(uid)
+            .setHandleInvalid("optimistic")
+            .setInputCol(term)
+            .setSize(size)
+          (term, term)
         case _ =>
           (term, term)
       }
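
For readers unfamiliar with `VectorSizeHint`, the new `case _: VectorUDT =>` branch above adds a size-declaring transformer to the encoder stages. Below is a standalone sketch of what such a stage does; the DataFrame `df` and the column name "b" are illustrative and not taken from the patch.

```scala
import org.apache.spark.ml.feature.VectorSizeHint

// Declares that the vector column "b" always has length 3. With
// handleInvalid = "optimistic", rows are not validated against the declared
// size; the hint only rewrites the column metadata so downstream stages
// (e.g. VectorAssembler) know the size without inspecting the data.
val sizeHint = new VectorSizeHint()
  .setInputCol("b")
  .setSize(3)
  .setHandleInvalid("optimistic")

val withSizeMetadata = sizeHint.transform(df)
```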
@@ -17,15 +17,15 @@
 package org.apache.spark.ml.feature

-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
 import org.apache.spark.sql.types.DoubleType

-class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class RFormulaSuite extends MLTest with DefaultReadWriteTest {
   import testImplicits._
@@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     assert(result3.collect() === expected3.collect())
     assert(result4.collect() === expected4.collect())
   }
+
+  test("Use Vectors as inputs to formula.") {
+    val original = Seq(
+      (1, 4, Vectors.dense(0.0, 0.0, 4.0)),
+      (2, 4, Vectors.dense(1.0, 0.0, 4.0)),
+      (3, 5, Vectors.dense(1.0, 0.0, 5.0)),
+      (4, 5, Vectors.dense(0.0, 1.0, 5.0))
+    ).toDF("id", "a", "b")
+    val formula = new RFormula().setFormula("id ~ a + b")
+    val (first +: rest) = Seq("id", "a", "b", "features", "label")
+    testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) {
+      case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
+        assert(label === id)
+        assert(features.toArray === a +: b.toArray)
+    }
+
+    val group = new AttributeGroup("b", 3)
+    val vectorColWithMetadata = original("b").as("b", group.toMetadata())
+    val dfWithMetadata = original.withColumn("b", vectorColWithMetadata)
+    val model = formula.fit(dfWithMetadata)
+    // model should work even when applied to dataframe without metadata.
+    testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) {
+      case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) =>
+        assert(label === id)
+        assert(features.toArray === a +: b.toArray)
+    }
+  }
 }