diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 912bd95a2ec702ac85a1d25c41b8dadf733702e9..555f1130e46a0cf4c27a1924c3dae5fcc1a58dcd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -150,6 +150,7 @@ class StringIndexerModel ( "Skip StringIndexerModel.") return dataset } + validateAndTransformSchema(dataset.schema) val indexer = udf { label: String => if (labelToIndex.contains(label)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 5d199ca9b51b17ab6ddde92cf9d8e67946f61bb9..0dbaed25229575c390f42e4927512a08d8dcbc80 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -118,6 +118,17 @@ class StringIndexerSuite assert(indexerModel.transform(df).eq(df)) } + test("StringIndexerModel can't overwrite output column") { + val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output") + val indexer = new StringIndexer() + .setInputCol("input") + .setOutputCol("output") + .fit(df) + intercept[IllegalArgumentException] { + indexer.transform(df) + } + } + test("StringIndexer read/write") { val t = new StringIndexer() .setInputCol("myInputCol")