Skip to content
Snippets Groups Projects
Commit b1835d72 authored by Grzegorz Chilkiewicz's avatar Grzegorz Chilkiewicz Committed by Joseph K. Bradley
Browse files

[SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication

Fixes problem and verifies fix by test suite.
Also - adds optional parameter: nullable (Boolean) to: SchemaUtils.appendColumn
and deduplicates SchemaUtils.appendColumn functions.

Author: Grzegorz Chilkiewicz <grzegorz.chilkiewicz@codilime.com>

Closes #10741 from grzegorz-chilkiewicz/master.
parent 358300c7
No related branches found
No related tags found
No related merge requests found
......@@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String)
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
val outputFields = schema.fields :+
StructField($(outputCol), inputType, schema($(inputCol)).nullable)
StructType(outputFields)
SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable)
}
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
......
......@@ -71,12 +71,10 @@ private[spark] object SchemaUtils {
def appendColumn(
schema: StructType,
colName: String,
dataType: DataType): StructType = {
dataType: DataType,
nullable: Boolean = false): StructType = {
if (colName.isEmpty) return schema
val fieldNames = schema.fieldNames
require(!fieldNames.contains(colName), s"Column $colName already exists.")
val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
StructType(outputFields)
appendColumn(schema, StructField(colName, dataType, nullable))
}
/**
......
......@@ -89,4 +89,19 @@ class StopWordsRemoverSuite
.setCaseSensitive(true)
testDefaultReadWrite(t)
}
test("StopWordsRemover output column already exists") {
val outputCol = "expected"
val remover = new StopWordsRemover()
.setInputCol("raw")
.setOutputCol(outputCol)
val dataSet = sqlContext.createDataFrame(Seq(
(Seq("The", "the", "swift"), Seq("swift"))
)).toDF("raw", outputCol)
val thrown = intercept[IllegalArgumentException] {
testStopWordsRemover(remover, dataSet)
}
assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.")
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment