Skip to content
Snippets Groups Projects
Commit 6cbde337 authored by Yanbo Liang's avatar Yanbo Liang Committed by Sean Owen
Browse files

[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for...

[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests.

## What changes were proposed in this pull request?
This is follow-up for #14378. When we add ```transformSchema``` for all estimators and transformers, I found there are tests failed for ```StringIndexer``` and ```VectorAssembler```. So I moved these parts of work separately in this PR, to make it more clear to review.
The corresponding tests should throw ```IllegalArgumentException``` at schema validation period after we add ```transformSchema```. It's efficient that to throw exception at the start of ```fit``` or ```transform``` rather than during the process.

## How was this patch tested?
Modified unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14455 from yanboliang/transformSchema.
parent 1f96c97f
No related branches found
No related tags found
No related merge requests found
......@@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
val counts = dataset.select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
......@@ -160,7 +161,7 @@ class StringIndexerModel (
"Skip StringIndexerModel.")
return dataset.toDF
}
validateAndTransformSchema(dataset.schema)
transformSchema(dataset.schema, logging = true)
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
......@@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val inputColSchema = dataset.schema($(inputCol))
// If the labels array is empty use column metadata
val values = if (!isDefined(labels) || $(labels).isEmpty) {
......
......@@ -51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
// Schema transformation.
val schema = dataset.schema
lazy val first = dataset.toDF.first()
......
......@@ -120,12 +120,20 @@ class StringIndexerSuite
test("StringIndexerModel can't overwrite output column") {
val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
intercept[IllegalArgumentException] {
new StringIndexer()
.setInputCol("input")
.setOutputCol("output")
.fit(df)
}
val indexer = new StringIndexer()
.setInputCol("input")
.setOutputCol("output")
.setOutputCol("indexedInput")
.fit(df)
intercept[IllegalArgumentException] {
indexer.transform(df)
indexer.setOutputCol("output").transform(df)
}
}
......
......@@ -74,10 +74,10 @@ class VectorAssemblerSuite
val assembler = new VectorAssembler()
.setInputCols(Array("a", "b", "c"))
.setOutputCol("features")
val thrown = intercept[SparkException] {
val thrown = intercept[IllegalArgumentException] {
assembler.transform(df)
}
assert(thrown.getMessage contains "VectorAssembler does not support the StringType type")
assert(thrown.getMessage contains "Data type StringType is not supported")
}
test("ML attributes") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment